diff --git a/django-stubs/contrib/auth/base_user.pyi b/django-stubs/contrib/auth/base_user.pyi index f9179d2..945a60e 100644 --- a/django-stubs/contrib/auth/base_user.pyi +++ b/django-stubs/contrib/auth/base_user.pyi @@ -1,12 +1,16 @@ -from typing import Any, Optional, Tuple, List, overload +from typing import Any, Optional, Tuple, List, overload, TypeVar + +from django.db.models.base import Model from django.db import models -class BaseUserManager(models.Manager): +_T = TypeVar('_T', bound=Model) + +class BaseUserManager(models.Manager[_T]): @classmethod def normalize_email(cls, email: Optional[str]) -> str: ... def make_random_password(self, length: int = ..., allowed_chars: str = ...) -> str: ... - def get_by_natural_key(self, username: Optional[str]) -> AbstractBaseUser: ... + def get_by_natural_key(self, username: Optional[str]) -> _T: ... class AbstractBaseUser(models.Model): password: models.CharField = ... diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index c642c79..6c9d090 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -1,7 +1,8 @@ -from typing import Any, Collection, Optional, Set, Tuple, Type, Union +from typing import Any, List, Optional, Set, Tuple, Type, Union, TypeVar from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager from django.contrib.contenttypes.models import ContentType +from django.db.models.base import Model from django.db.models.manager import EmptyManager from django.contrib.auth.validators import UnicodeUsernameValidator @@ -27,13 +28,15 @@ class Group(models.Model): permissions: models.ManyToManyField = models.ManyToManyField(Permission) def natural_key(self): ... -class UserManager(BaseUserManager): +_T = TypeVar('_T', bound=Model) + +class UserManager(BaseUserManager[_T]): def create_user( self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any - ) -> AbstractBaseUser: ... + ) -> _T: ... def create_superuser( self, username: str, email: Optional[str], password: Optional[str], **extra_fields: Any - ) -> AbstractBaseUser: ... + ) -> _T: ... class PermissionsMixin(models.Model): is_superuser: models.BooleanField = ... diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index cdcd707..a91393d 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,22 +1,19 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import ( - TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type, -) +from typing import (Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type) -from django.contrib.postgres.fields import ArrayField from django.core.exceptions import FieldError from django.db.models.base import Model -from django.db.models.fields import CharField, Field, AutoField from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.sql.query import Query from django.utils.functional import cached_property from mypy.checker import TypeChecker -from mypy.types import Instance -from mypy.types import Type as MypyType +from mypy.types import Instance, Type as MypyType +from django.contrib.postgres.fields import ArrayField +from django.db.models.fields import AutoField, CharField, Field from mypy_django_plugin.lib import helpers if TYPE_CHECKING: @@ -210,13 +207,19 @@ class DjangoContext: if isinstance(field, ForeignKey): field_name = field.name foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) - related_model_info = helpers.lookup_class_typeinfo(api, field.related_model) + + related_model = field.related_model + if related_model._meta.proxy_for_model: + related_model = field.related_model._meta.proxy_for_model + + related_model_info = helpers.lookup_class_typeinfo(api, related_model) is_nullable = self.fields_context.get_field_nullability(field, method) foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, '_pyi_private_set_type', is_nullable=is_nullable) model_set_type = helpers.convert_any_to_type(foreign_key_set_type, Instance(related_model_info, [])) + expected_types[field_name] = model_set_type elif isinstance(field, GenericForeignKey): diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 47e5838..5dddd72 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,17 +1,13 @@ from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro -from mypy.nodes import ( - GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, - SymbolTableNode, TypeInfo, Var, -) +from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, + SymbolTable, SymbolTableNode, TypeInfo, Var) from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext -from mypy.types import AnyType, Instance, NoneTyp, TupleType -from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index f393ea9..ee77ed1 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -39,15 +39,20 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context return AnyType(TypeOfAny.from_error) assert isinstance(current_field, RelatedField) - referred_to_typeinfo = helpers.lookup_class_typeinfo(ctx.api, current_field.related_model) - referred_to_type = Instance(referred_to_typeinfo, []) + + related_model = related_model_to_set = current_field.related_model + if related_model_to_set._meta.proxy_for_model: + related_model_to_set = related_model._meta.proxy_for_model + + related_model_info = helpers.lookup_class_typeinfo(ctx.api, related_model) + related_model_to_set_info = helpers.lookup_class_typeinfo(ctx.api, related_model_to_set) default_related_field_type = set_descriptor_types_for_field(ctx) # replace Any with referred_to_type - args = [] - for default_arg in default_related_field_type.args: - args.append(helpers.convert_any_to_type(default_arg, referred_to_type)) - + args = [ + helpers.convert_any_to_type(default_related_field_type.args[0], Instance(related_model_to_set_info, [])), + helpers.convert_any_to_type(default_related_field_type.args[1], Instance(related_model_info, [])), + ] return helpers.reparametrize_instance(default_related_field_type, new_args=args) diff --git a/test-data/typecheck/models/test_proxy_models.yml b/test-data/typecheck/models/test_proxy_models.yml new file mode 100644 index 0000000..ba46845 --- /dev/null +++ b/test-data/typecheck/models/test_proxy_models.yml @@ -0,0 +1,21 @@ +- case: foreign_key_to_proxy_model_accepts_first_non_proxy_model + main: | + from myapp.models import Blog, Publisher, PublisherProxy + Blog(publisher=Publisher()) + Blog.objects.create(publisher=Publisher()) + Blog().publisher = Publisher() + reveal_type(Blog().publisher) # N: Revealed type is 'myapp.models.PublisherProxy*' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Publisher(models.Model): + pass + class PublisherProxy(Publisher): + class Meta: + proxy = True + class Blog(models.Model): + publisher = models.ForeignKey(to=PublisherProxy, on_delete=models.CASCADE)