Support passing extra classes to monkeypatch (#953)

* Support passing extra classes to monkeypatch

Closes https://github.com/typeddjango/django-stubs/issues/946#issuecomment-1122895190

* Move extra classes into separate test

* Avoid mutable default

* Fix protocol arguments
This commit is contained in:
henribru
2022-05-12 19:14:59 +02:00
committed by GitHub
parent ccef6779ad
commit 8fe2bd4b9b
2 changed files with 32 additions and 6 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
from typing import Any, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
from django import VERSION as VERSION
from django.contrib.admin import ModelAdmin
@@ -58,7 +58,7 @@ _need_generic: List[MPGeneric[Any]] = [
]
def monkeypatch() -> None:
def monkeypatch(extra_classes: Optional[Iterable[type]] = None) -> None:
"""Monkey patch django as necessary to work properly with mypy."""
# Add the __class_getitem__ dunder.
@@ -68,6 +68,9 @@ def monkeypatch() -> None:
)
for el in suited_for_this_version:
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)
if extra_classes:
for cls in extra_classes:
cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) # type: ignore[attr-defined]
__all__ = ["monkeypatch"]

View File

@@ -1,12 +1,13 @@
import builtins
from contextlib import suppress
from typing import Optional
from typing import Iterable, Optional
import pytest
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
from django.db.models import Model
from django.forms.models import ModelForm
from django.views import View
from typing_extensions import Protocol
import django_stubs_ext
@@ -17,7 +18,9 @@ from django_stubs_ext.patch import _need_generic, _VersionSpec
class _MakeGenericClasses(Protocol):
"""Used to represent a type of ``make_generic_classes`` fixture."""
def __call__(self, django_version: Optional[_VersionSpec] = None) -> None:
def __call__(
self, django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None
) -> None:
...
@@ -26,15 +29,22 @@ def make_generic_classes(
request: FixtureRequest,
monkeypatch: MonkeyPatch,
) -> _MakeGenericClasses:
_extra_classes: list[type] = []
def fin() -> None:
for el in _need_generic:
with suppress(AttributeError):
delattr(el.cls, "__class_getitem__")
for cls in _extra_classes:
with suppress(AttributeError):
delattr(cls, "__class_getitem__")
def factory(django_version: Optional[_VersionSpec] = None) -> None:
def factory(django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None) -> None:
if extra_classes:
_extra_classes.extend(extra_classes)
if django_version is not None:
monkeypatch.setattr(patch, "VERSION", django_version)
django_stubs_ext.monkeypatch()
django_stubs_ext.monkeypatch(extra_classes)
request.addfinalizer(fin)
return factory
@@ -52,6 +62,19 @@ def test_patched_generics(make_generic_classes: _MakeGenericClasses) -> None:
pass
def test_patched_extra_classes_generics(make_generic_classes: _MakeGenericClasses) -> None:
"""Test that the generics actually get patched for extra classes."""
extra_classes = [View]
make_generic_classes(django_version=None, extra_classes=extra_classes)
for cls in extra_classes:
assert cls[type] is cls # type: ignore[misc]
class TestView(View[Model]): # type: ignore[type-arg]
pass
@pytest.mark.parametrize(
"django_version",
[