diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py index 798a080..f2bca67 100644 --- a/mypy_django_plugin/lib/fullnames.py +++ b/mypy_django_plugin/lib/fullnames.py @@ -1,3 +1,5 @@ +ABSTRACT_USER_MODEL_FULLNAME = "django.contrib.auth.models.AbstractUser" +PERMISSION_MIXIN_CLASS_FULLNAME = "django.contrib.auth.models.PermissionsMixin" MODEL_CLASS_FULLNAME = "django.db.models.base.Model" FIELD_FULLNAME = "django.db.models.fields.Field" CHAR_FIELD_FULLNAME = "django.db.models.fields.CharField" diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 9ca2a93..2b2365f 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -22,7 +22,7 @@ from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings from mypy_django_plugin.transformers.managers import create_new_manager_class_from_from_queryset_method -from mypy_django_plugin.transformers.models import process_model_class +from mypy_django_plugin.transformers.models import process_model_class, set_auth_user_model_boolean_fields def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None: @@ -270,8 +270,12 @@ class NewSemanalDjangoPlugin(Plugin): return partial(settings.get_type_of_settings_attribute, django_context=self.django_context) info = self._get_typeinfo_or_none(class_name) + if info and info.has_base(fullnames.PERMISSION_MIXIN_CLASS_FULLNAME) and attr_name == "is_superuser": + return partial(set_auth_user_model_boolean_fields, django_context=self.django_context) if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == "user": return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) + if info and info.has_base(fullnames.ABSTRACT_USER_MODEL_FULLNAME) and attr_name in ("is_staff", "is_active"): + return partial(set_auth_user_model_boolean_fields, django_context=self.django_context) return None def get_dynamic_class_hook(self, fullname: str) -> Optional[Callable[[DynamicClassDefContext], None]]: diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 4c1b492..751f895 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -5,7 +5,7 @@ from django.db.models.fields import DateField, DateTimeField from django.db.models.fields.related import ForeignKey from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var -from mypy.plugin import ClassDefContext +from mypy.plugin import AttributeContext, ClassDefContext from mypy.plugins import common from mypy.semanal import SemanticAnalyzer from mypy.types import AnyType, Instance @@ -355,3 +355,9 @@ def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> except helpers.IncompleteDefnException: if not ctx.api.final_iteration: ctx.api.defer() + + +def set_auth_user_model_boolean_fields(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: + boolinfo = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), bool) + assert boolinfo is not None + return Instance(boolinfo, []) diff --git a/tests/typecheck/contrib/auth/test_misc.yml b/tests/typecheck/contrib/auth/test_misc.yml index b9207ce..39fe44b 100644 --- a/tests/typecheck/contrib/auth/test_misc.yml +++ b/tests/typecheck/contrib/auth/test_misc.yml @@ -13,3 +13,25 @@ from django.test import Client get_user(Client()) +- case: test_user_fields + main: | + from typing import Union + from django.contrib.auth.models import AnonymousUser, User + + anonymous: AnonymousUser + anonymous_is_staff: bool = anonymous.is_staff + anonymous_is_active: bool = anonymous.is_active + anonymous_is_superuser: bool = anonymous.is_superuser + anonymous_is_authenticated: bool = anonymous.is_authenticated + + user: User + user_is_staff: bool = user.is_staff + user_is_active: bool = user.is_active + user_is_superuser: bool = user.is_superuser + user_is_authenticated: bool = user.is_authenticated + + union: Union[User, AnonymousUser] + union_is_staff: bool = union.is_staff + union_is_active: bool = union.is_active + union_is_superuser: bool = union.is_superuser + union_is_authenticated: bool = union.is_authenticated diff --git a/tests/typecheck/models/test_contrib_models.yml b/tests/typecheck/models/test_contrib_models.yml index 946d2de..73be029 100644 --- a/tests/typecheck/models/test_contrib_models.yml +++ b/tests/typecheck/models/test_contrib_models.yml @@ -6,8 +6,8 @@ reveal_type(User().first_name) # N: Revealed type is 'builtins.str*' reveal_type(User().last_name) # N: Revealed type is 'builtins.str*' reveal_type(User().email) # N: Revealed type is 'builtins.str*' - reveal_type(User().is_staff) # N: Revealed type is 'builtins.bool*' - reveal_type(User().is_active) # N: Revealed type is 'builtins.bool*' + reveal_type(User().is_staff) # N: Revealed type is 'builtins.bool' + reveal_type(User().is_active) # N: Revealed type is 'builtins.bool' reveal_type(User().date_joined) # N: Revealed type is 'datetime.datetime*' reveal_type(User().last_login) # N: Revealed type is 'Union[datetime.datetime, None]' reveal_type(User().is_authenticated) # N: Revealed type is 'Literal[True]' @@ -22,7 +22,7 @@ reveal_type(Permission().codename) # N: Revealed type is 'builtins.str*' from django.contrib.auth.models import PermissionsMixin - reveal_type(PermissionsMixin().is_superuser) # N: Revealed type is 'builtins.bool*' + reveal_type(PermissionsMixin().is_superuser) # N: Revealed type is 'builtins.bool' from django.contrib.auth.models import Group reveal_type(Group().name) # N: Revealed type is 'builtins.str*'