From 8fe2bd4b9b6c6b13d893d33eddb12b9118d5aa94 Mon Sep 17 00:00:00 2001 From: henribru <6639509+henribru@users.noreply.github.com> Date: Thu, 12 May 2022 19:14:59 +0200 Subject: [PATCH] Support passing extra classes to monkeypatch (#953) * Support passing extra classes to monkeypatch Closes https://github.com/typeddjango/django-stubs/issues/946#issuecomment-1122895190 * Move extra classes into separate test * Avoid mutable default * Fix protocol arguments --- django_stubs_ext/django_stubs_ext/patch.py | 7 +++-- django_stubs_ext/tests/test_monkeypatching.py | 31 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/django_stubs_ext/django_stubs_ext/patch.py b/django_stubs_ext/django_stubs_ext/patch.py index 7f6dd71..17d50cc 100644 --- a/django_stubs_ext/django_stubs_ext/patch.py +++ b/django_stubs_ext/django_stubs_ext/patch.py @@ -1,4 +1,4 @@ -from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar +from typing import Any, Generic, Iterable, List, Optional, Tuple, Type, TypeVar from django import VERSION as VERSION from django.contrib.admin import ModelAdmin @@ -58,7 +58,7 @@ _need_generic: List[MPGeneric[Any]] = [ ] -def monkeypatch() -> None: +def monkeypatch(extra_classes: Optional[Iterable[type]] = None) -> None: """Monkey patch django as necessary to work properly with mypy.""" # Add the __class_getitem__ dunder. @@ -68,6 +68,9 @@ def monkeypatch() -> None: ) for el in suited_for_this_version: el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) + if extra_classes: + for cls in extra_classes: + cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) # type: ignore[attr-defined] __all__ = ["monkeypatch"] diff --git a/django_stubs_ext/tests/test_monkeypatching.py b/django_stubs_ext/tests/test_monkeypatching.py index 4e3c162..c6f9bf9 100644 --- a/django_stubs_ext/tests/test_monkeypatching.py +++ b/django_stubs_ext/tests/test_monkeypatching.py @@ -1,12 +1,13 @@ import builtins from contextlib import suppress -from typing import Optional +from typing import Iterable, Optional import pytest from _pytest.fixtures import FixtureRequest from _pytest.monkeypatch import MonkeyPatch from django.db.models import Model from django.forms.models import ModelForm +from django.views import View from typing_extensions import Protocol import django_stubs_ext @@ -17,7 +18,9 @@ from django_stubs_ext.patch import _need_generic, _VersionSpec class _MakeGenericClasses(Protocol): """Used to represent a type of ``make_generic_classes`` fixture.""" - def __call__(self, django_version: Optional[_VersionSpec] = None) -> None: + def __call__( + self, django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None + ) -> None: ... @@ -26,15 +29,22 @@ def make_generic_classes( request: FixtureRequest, monkeypatch: MonkeyPatch, ) -> _MakeGenericClasses: + _extra_classes: list[type] = [] + def fin() -> None: for el in _need_generic: with suppress(AttributeError): delattr(el.cls, "__class_getitem__") + for cls in _extra_classes: + with suppress(AttributeError): + delattr(cls, "__class_getitem__") - def factory(django_version: Optional[_VersionSpec] = None) -> None: + def factory(django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None) -> None: + if extra_classes: + _extra_classes.extend(extra_classes) if django_version is not None: monkeypatch.setattr(patch, "VERSION", django_version) - django_stubs_ext.monkeypatch() + django_stubs_ext.monkeypatch(extra_classes) request.addfinalizer(fin) return factory @@ -52,6 +62,19 @@ def test_patched_generics(make_generic_classes: _MakeGenericClasses) -> None: pass +def test_patched_extra_classes_generics(make_generic_classes: _MakeGenericClasses) -> None: + """Test that the generics actually get patched for extra classes.""" + extra_classes = [View] + + make_generic_classes(django_version=None, extra_classes=extra_classes) + + for cls in extra_classes: + assert cls[type] is cls # type: ignore[misc] + + class TestView(View[Model]): # type: ignore[type-arg] + pass + + @pytest.mark.parametrize( "django_version", [