Adds more types to patch

This commit is contained in:
sobolevn
2020-11-14 20:46:32 +03:00
parent 3e0f144148
commit 517ae648e5
11 changed files with 128 additions and 25 deletions

View File

@@ -1,11 +1,14 @@
from typing import Any, Generic, List, Optional, Type, TypeVar
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
import django
from django.contrib.admin import ModelAdmin
from django.contrib.admin.options import BaseModelAdmin
from django.db.models.manager import BaseManager
from django.db.models.query import QuerySet
from django.views.generic.edit import FormMixin
_T = TypeVar("_T")
_VersionSpec = Tuple[int, int]
class MPGeneric(Generic[_T]):
@@ -21,11 +24,15 @@ class MPGeneric(Generic[_T]):
version: Optional[int]
cls: Type[_T]
def __init__(self, cls: Type[_T], version: Optional[int] = None):
def __init__(self, cls: Type[_T], version: Optional[_VersionSpec] = None):
"""Set the data fields, basic constructor."""
self.version = version
self.cls = cls
def __repr__(self) -> str:
"""Better representation in tests and debug."""
return "<MPGeneric: {0}, versions={1}>".format(self.cls, self.version or "all")
# certain django classes need to be generic, but lack the __class_getitem__ dunder needed to
# annotate them: https://github.com/typeddjango/django-stubs/issues/507
@@ -34,13 +41,20 @@ _need_generic: List[MPGeneric[Any]] = [
MPGeneric(ModelAdmin),
MPGeneric(FormMixin),
MPGeneric(BaseModelAdmin),
# These types do have native `__class_getitem__` method since django 3.1:
MPGeneric(QuerySet, (3, 1)),
MPGeneric(BaseManager, (3, 1)),
]
# currently just adds the __class_getitem__ dunder. if more monkeypatching is needed, add it here
def monkeypatch() -> None:
"""Monkey patch django as necessary to work properly with mypy."""
for el in filter(lambda x: django.VERSION[0] == x.version or x.version is None, _need_generic):
suited_for_this_version = filter(
spec.version is None or django.VERSION[:2] <= spec.version,
_need_generic,
)
for el in suited_for_this_version:
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)

View File

@@ -1,11 +1,49 @@
import pytest
import django_stubs_ext
from django_stubs_ext.monkeypatch import _need_generic
django_stubs_ext.monkeypatch()
from django_stubs_ext.monkeypatch import _need_generic, _VersionSpec, django
def test_patched_generics() -> None:
@pytest.fixture(scope="function")
def make_generic_classes(request, monkeypatch):
def fin():
for el in _need_generic:
delattr(el.cls, "__class_getitem__")
def factory(django_version=None):
if django_version is not None:
monkeypatch.setattr(django, "VERSION", django_version)
django_stubs_ext.monkeypatch()
request.addfinalizer(fin)
return factory
def test_patched_generics(make_generic_classes) -> None:
"""Test that the generics actually get patched."""
make_generic_classes()
for el in _need_generic:
# This only throws an exception if the monkeypatch failed
assert el.cls[type] == el.cls # `type` is arbitrary
if el.version is None:
assert el.cls[type] is el.cls # `type` is arbitrary
@pytest.mark.parametrize(
"django_version",
[
(2, 2),
(3, 0),
(3, 1),
(3, 2),
],
)
def test_patched_version_specific(
django_version: _VersionSpec,
make_generic_classes,
) -> None:
"""Test version speicific types."""
make_generic_classes(django_version)
for el in _need_generic:
if el.version is not None and el.version[:2] <= django_version:
assert el.cls[int] is el.cls