mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-11 22:41:55 +08:00
Support passing extra classes to monkeypatch (#953)
* Support passing extra classes to monkeypatch Closes https://github.com/typeddjango/django-stubs/issues/946#issuecomment-1122895190 * Move extra classes into separate test * Avoid mutable default * Fix protocol arguments
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
|
from typing import Any, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
from django import VERSION as VERSION
|
from django import VERSION as VERSION
|
||||||
from django.contrib.admin import ModelAdmin
|
from django.contrib.admin import ModelAdmin
|
||||||
@@ -58,7 +58,7 @@ _need_generic: List[MPGeneric[Any]] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def monkeypatch() -> None:
|
def monkeypatch(extra_classes: Optional[Iterable[type]] = None) -> None:
|
||||||
"""Monkey patch django as necessary to work properly with mypy."""
|
"""Monkey patch django as necessary to work properly with mypy."""
|
||||||
|
|
||||||
# Add the __class_getitem__ dunder.
|
# Add the __class_getitem__ dunder.
|
||||||
@@ -68,6 +68,9 @@ def monkeypatch() -> None:
|
|||||||
)
|
)
|
||||||
for el in suited_for_this_version:
|
for el in suited_for_this_version:
|
||||||
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)
|
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)
|
||||||
|
if extra_classes:
|
||||||
|
for cls in extra_classes:
|
||||||
|
cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["monkeypatch"]
|
__all__ = ["monkeypatch"]
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import builtins
|
import builtins
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.fixtures import FixtureRequest
|
from _pytest.fixtures import FixtureRequest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from django.db.models import Model
|
from django.db.models import Model
|
||||||
from django.forms.models import ModelForm
|
from django.forms.models import ModelForm
|
||||||
|
from django.views import View
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
import django_stubs_ext
|
import django_stubs_ext
|
||||||
@@ -17,7 +18,9 @@ from django_stubs_ext.patch import _need_generic, _VersionSpec
|
|||||||
class _MakeGenericClasses(Protocol):
|
class _MakeGenericClasses(Protocol):
|
||||||
"""Used to represent a type of ``make_generic_classes`` fixture."""
|
"""Used to represent a type of ``make_generic_classes`` fixture."""
|
||||||
|
|
||||||
def __call__(self, django_version: Optional[_VersionSpec] = None) -> None:
|
def __call__(
|
||||||
|
self, django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None
|
||||||
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@@ -26,15 +29,22 @@ def make_generic_classes(
|
|||||||
request: FixtureRequest,
|
request: FixtureRequest,
|
||||||
monkeypatch: MonkeyPatch,
|
monkeypatch: MonkeyPatch,
|
||||||
) -> _MakeGenericClasses:
|
) -> _MakeGenericClasses:
|
||||||
|
_extra_classes: list[type] = []
|
||||||
|
|
||||||
def fin() -> None:
|
def fin() -> None:
|
||||||
for el in _need_generic:
|
for el in _need_generic:
|
||||||
with suppress(AttributeError):
|
with suppress(AttributeError):
|
||||||
delattr(el.cls, "__class_getitem__")
|
delattr(el.cls, "__class_getitem__")
|
||||||
|
for cls in _extra_classes:
|
||||||
|
with suppress(AttributeError):
|
||||||
|
delattr(cls, "__class_getitem__")
|
||||||
|
|
||||||
def factory(django_version: Optional[_VersionSpec] = None) -> None:
|
def factory(django_version: Optional[_VersionSpec] = None, extra_classes: Optional[Iterable[type]] = None) -> None:
|
||||||
|
if extra_classes:
|
||||||
|
_extra_classes.extend(extra_classes)
|
||||||
if django_version is not None:
|
if django_version is not None:
|
||||||
monkeypatch.setattr(patch, "VERSION", django_version)
|
monkeypatch.setattr(patch, "VERSION", django_version)
|
||||||
django_stubs_ext.monkeypatch()
|
django_stubs_ext.monkeypatch(extra_classes)
|
||||||
|
|
||||||
request.addfinalizer(fin)
|
request.addfinalizer(fin)
|
||||||
return factory
|
return factory
|
||||||
@@ -52,6 +62,19 @@ def test_patched_generics(make_generic_classes: _MakeGenericClasses) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_extra_classes_generics(make_generic_classes: _MakeGenericClasses) -> None:
|
||||||
|
"""Test that the generics actually get patched for extra classes."""
|
||||||
|
extra_classes = [View]
|
||||||
|
|
||||||
|
make_generic_classes(django_version=None, extra_classes=extra_classes)
|
||||||
|
|
||||||
|
for cls in extra_classes:
|
||||||
|
assert cls[type] is cls # type: ignore[misc]
|
||||||
|
|
||||||
|
class TestView(View[Model]): # type: ignore[type-arg]
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"django_version",
|
"django_version",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user