From f2e79d3bfbe828335c20628526360eb5014baf8c Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 18 Jul 2019 18:31:37 +0300 Subject: [PATCH] add GenericForeignKey support, remove some false-positives --- django-stubs/apps/registry.pyi | 4 +- django-stubs/contrib/contenttypes/fields.pyi | 11 ++-- django-stubs/core/management/__init__.pyi | 6 +- django-stubs/core/serializers/__init__.pyi | 7 +-- django-stubs/core/serializers/base.pyi | 4 +- django-stubs/db/models/base.pyi | 1 - django-stubs/db/models/fields/__init__.pyi | 7 ++- django-stubs/db/models/options.pyi | 4 +- django-stubs/db/models/query.pyi | 6 +- django-stubs/utils/translation/__init__.pyi | 6 +- mypy_django_plugin/django/context.py | 60 +++++++++++-------- mypy_django_plugin/transformers/fields.py | 29 +++++---- .../fields/test_generic_foreign_key.yml | 22 +++++++ test-data/typecheck/fields/test_related.yml | 4 +- test-data/typecheck/models/test_create.yml | 2 +- 15 files changed, 111 insertions(+), 62 deletions(-) create mode 100644 test-data/typecheck/fields/test_generic_foreign_key.yml diff --git a/django-stubs/apps/registry.pyi b/django-stubs/apps/registry.pyi index 661f35c..69381ad 100644 --- a/django-stubs/apps/registry.pyi +++ b/django-stubs/apps/registry.pyi @@ -8,8 +8,8 @@ from django.db.models.base import Model from .config import AppConfig class Apps: - all_models: 'Dict[str, OrderedDict[str, Type[Model]]]' = ... - app_configs: 'OrderedDict[str, AppConfig]' = ... + all_models: "Dict[str, OrderedDict[str, Type[Model]]]" = ... + app_configs: "OrderedDict[str, AppConfig]" = ... stored_app_configs: List[Any] = ... apps_ready: bool = ... ready_event: threading.Event = ... diff --git a/django-stubs/contrib/contenttypes/fields.pyi b/django-stubs/contrib/contenttypes/fields.pyi index 1b80e2e..de43919 100644 --- a/django-stubs/contrib/contenttypes/fields.pyi +++ b/django-stubs/contrib/contenttypes/fields.pyi @@ -7,6 +7,7 @@ from django.db.models.fields.related import ForeignObject from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor from django.db.models.fields.reverse_related import ForeignObjectRel +from django.db.models.expressions import Combinable from django.db.models.fields import Field, PositiveIntegerField from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.query import QuerySet @@ -14,6 +15,10 @@ from django.db.models.query_utils import FilteredRelation, PathInfo from django.db.models.sql.where import WhereNode class GenericForeignKey(FieldCacheMixin): + # django-stubs implementation only fields + _pyi_private_set_type: Union[Any, Combinable] + _pyi_private_get_type: Any + # attributes auto_created: bool = ... concrete: bool = ... editable: bool = ... @@ -44,10 +49,8 @@ class GenericForeignKey(FieldCacheMixin): def get_prefetch_queryset( self, instances: Union[List[Model], QuerySet], queryset: Optional[QuerySet] = ... ) -> Tuple[List[Model], Callable, Callable, bool, str, bool]: ... - def __get__( - self, instance: Optional[Model], cls: Type[Model] = ... - ) -> Optional[Union[GenericForeignKey, Model]]: ... - def __set__(self, instance: Model, value: Optional[Model]) -> None: ... + def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Optional[Any]: ... + def __set__(self, instance: Model, value: Optional[Any]) -> None: ... class GenericRel(ForeignObjectRel): field: GenericRelation diff --git a/django-stubs/core/management/__init__.pyi b/django-stubs/core/management/__init__.pyi index 92d994e..2b5bef2 100644 --- a/django-stubs/core/management/__init__.pyi +++ b/django-stubs/core/management/__init__.pyi @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union -from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError +from .base import BaseCommand as BaseCommand, CommandError as CommandError def find_commands(management_dir: str) -> List[str]: ... def load_command_class(app_name: str, name: str) -> BaseCommand: ... def get_commands() -> Dict[str, str]: ... -def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> Optional[str]: ... +def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> str: ... class ManagementUtility: argv: List[str] = ... diff --git a/django-stubs/core/serializers/__init__.pyi b/django-stubs/core/serializers/__init__.pyi index d9eab40..d7a7133 100644 --- a/django-stubs/core/serializers/__init__.pyi +++ b/django-stubs/core/serializers/__init__.pyi @@ -11,6 +11,7 @@ from .base import ( SerializationError as SerializationError, DeserializationError as DeserializationError, M2MDeserializationError as M2MDeserializationError, + DeserializedObject, ) BUILTIN_SERIALIZERS: Any @@ -27,10 +28,8 @@ def get_serializer(format: str) -> Union[Type[Serializer], BadSerializer]: ... def get_serializer_formats() -> List[str]: ... def get_public_serializer_formats() -> List[str]: ... def get_deserializer(format: str) -> Union[Callable, Type[Deserializer]]: ... -def serialize( - format: str, queryset: Union[Iterator[Any], List[Model], QuerySet], **options: Any -) -> Optional[Union[bytes, str]]: ... -def deserialize(format: str, stream_or_string: Any, **options: Any) -> Union[Iterator[Any], Deserializer]: ... +def serialize(format: str, queryset: Iterable[Model], **options: Any) -> Optional[Union[bytes, str]]: ... +def deserialize(format: str, stream_or_string: Any, **options: Any) -> Iterator[DeserializedObject]: ... def sort_dependencies( app_list: Union[Iterable[Tuple[AppConfig, None]], Iterable[Tuple[str, Iterable[Type[Model]]]]] ) -> List[Type[Model]]: ... diff --git a/django-stubs/core/serializers/base.pyi b/django-stubs/core/serializers/base.pyi index 1057130..ed57d8e 100644 --- a/django-stubs/core/serializers/base.pyi +++ b/django-stubs/core/serializers/base.pyi @@ -70,8 +70,8 @@ class Deserializer: def __next__(self) -> None: ... class DeserializedObject: - object: Model = ... - m2m_data: Dict[Any, Any] = ... + object: Any = ... + m2m_data: Dict[str, List[int]] = ... def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ... def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ... diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 2d68e65..62fb57d 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -6,7 +6,6 @@ from django.core.checks.messages import CheckMessage from django.db.models.options import Options - class ModelBase(type): ... _Self = TypeVar("_Self", bound="Model") diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index cbbde04..7b7f174 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -1,7 +1,9 @@ import decimal import uuid from datetime import date, datetime, time, timedelta -from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence +from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence, List + +from django.core import checks from django.db.models import Model from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist @@ -41,6 +43,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): null: bool = ... editable: bool = ... choices: Optional[_FieldChoices] = ... + db_column: Optional[str] + column: str def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -86,6 +90,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): ) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ... def has_default(self) -> bool: ... def get_default(self) -> Any: ... + def check(self, **kwargs: Any) -> List[checks.Error]: ... class IntegerField(Field[_ST, _GT]): _pyi_private_set_type: Union[float, int, str, Combinable] diff --git a/django-stubs/db/models/options.pyi b/django-stubs/db/models/options.pyi index a0ba4aa..1f51ba0 100644 --- a/django-stubs/db/models/options.pyi +++ b/django-stubs/db/models/options.pyi @@ -108,4 +108,6 @@ class Options: def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ... def get_path_to_parent(self, parent: Type[Model]) -> List[PathInfo]: ... def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ... - def get_fields(self, include_parents: bool = ..., include_hidden: bool = ...) -> List[Union[Field, ForeignObjectRel]]: ... + def get_fields( + self, include_parents: bool = ..., include_hidden: bool = ... + ) -> List[Union[Field, ForeignObjectRel]]: ... diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index e53adda..649a32a 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -121,9 +121,11 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized): def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ... def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ... def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ... - def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ... + # TODO: return type + def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any, Any]: ... def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ... def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ... + # extra() return type won't be supported any time soon def extra( self, select: Optional[Dict[str, Any]] = ..., @@ -132,7 +134,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized): tables: Optional[List[str]] = ..., order_by: Optional[Sequence[str]] = ..., select_params: Optional[Sequence[Any]] = ..., - ) -> QuerySet[_T, _Row]: ... + ) -> QuerySet[Any, Any]: ... def reverse(self) -> QuerySet[_T, _Row]: ... def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ... def only(self, *fields: Any) -> QuerySet[_T, _Row]: ... diff --git a/django-stubs/utils/translation/__init__.pyi b/django-stubs/utils/translation/__init__.pyi index cc514d5..882dbb6 100644 --- a/django-stubs/utils/translation/__init__.pyi +++ b/django-stubs/utils/translation/__init__.pyi @@ -39,10 +39,10 @@ ungettext = ngettext def pgettext(context: str, message: str) -> str: ... def npgettext(context: str, singular: str, plural: str, number: int) -> str: ... -gettext_lazy: Any +gettext_lazy: Callable[[str], str] -ugettext_lazy: Any -pgettext_lazy: Any +ugettext_lazy: Callable[[str], str] +pgettext_lazy: Callable[[str], str] def ngettext_lazy(singular: Any, plural: Any, number: Optional[Any] = ...): ... diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index ea2fc8c..fd1f1f2 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,11 +1,13 @@ import os from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence +from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type -from django.core.exceptions import FieldError, FieldDoesNotExist +from django.core.exceptions import FieldError from django.db.models.base import Model from django.db.models.fields.related import ForeignKey, RelatedField +from django.db.models.fields.reverse_related import ForeignObjectRel +from django.db.models.sql.query import Query from django.utils.functional import cached_property from mypy.checker import TypeChecker from mypy.types import Instance, Type as MypyType @@ -13,9 +15,6 @@ from pytest_mypy.utils import temp_environ from django.contrib.postgres.fields import ArrayField from django.db.models.fields import CharField, Field -from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel - -from django.db.models.sql.query import Query from mypy_django_plugin.lib import helpers if TYPE_CHECKING: @@ -99,7 +98,7 @@ class DjangoFieldsContext: return Instance(model_info, []) else: return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', - is_nullable=is_nullable) + is_nullable=is_nullable) class DjangoLookupsContext: @@ -184,27 +183,38 @@ class DjangoContext: raise ValueError('No primary key defined') def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]: + from django.contrib.contenttypes.fields import GenericForeignKey + expected_types = {} - if method == '__init__': - # add pk - primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method) - expected_types['pk'] = field_set_type + # add pk + primary_key_field = self.get_primary_key_field(model_cls) + field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method) + expected_types['pk'] = field_set_type - for field in self.get_model_fields(model_cls): - field_name = field.attname - field_set_type = self.fields_context.get_field_set_type(api, field, method) - expected_types[field_name] = field_set_type + for field in model_cls._meta.get_fields(): + if isinstance(field, Field): + field_name = field.attname + field_set_type = self.fields_context.get_field_set_type(api, field, method) + expected_types[field_name] = field_set_type - if isinstance(field, ForeignKey): + if isinstance(field, ForeignKey): + field_name = field.name + foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) + related_model_info = helpers.lookup_class_typeinfo(api, field.related_model) + is_nullable = self.fields_context.get_field_nullability(field, method) + foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, + '_pyi_private_set_type', + is_nullable=is_nullable) + model_set_type = helpers.convert_any_to_type(foreign_key_set_type, + Instance(related_model_info, [])) + expected_types[field_name] = model_set_type + + elif isinstance(field, GenericForeignKey): + # it's generic, so cannot set specific model field_name = field.name - foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) - related_model_info = helpers.lookup_class_typeinfo(api, field.related_model) - is_nullable = self.fields_context.get_field_nullability(field, method) - foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, - '_pyi_private_set_type', - is_nullable=is_nullable) - model_set_type = helpers.convert_any_to_type(foreign_key_set_type, - Instance(related_model_info, [])) - expected_types[field_name] = model_set_type + gfk_info = helpers.lookup_class_typeinfo(api, field.__class__) + gfk_set_type = helpers.get_private_descriptor_type(gfk_info, '_pyi_private_set_type', + is_nullable=True) + expected_types[field_name] = gfk_set_type + return expected_types diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index ec3a7e9..d294043 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,6 +1,6 @@ from typing import Optional, Tuple, cast -from mypy.nodes import TypeInfo +from mypy.nodes import MypyFile, TypeInfo from mypy.plugin import FunctionContext from mypy.types import AnyType, CallableType, Instance, Type as MypyType, TypeOfAny @@ -14,11 +14,13 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC assert isinstance(to_arg_type.ret_type, Instance) return to_arg_type.ret_type.type.fullname() - outer_model_info = ctx.api.tscope.classes[-1] + outer_model_info = ctx.api.scope.active_class() + if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): + # not inside models.Model class + return None assert isinstance(outer_model_info, TypeInfo) to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to') - model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context) if model_string is None: # unresolvable @@ -28,10 +30,21 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC return outer_model_info.fullname() if '.' not in model_string: # same file class - current_module = ctx.api.tree - if model_string not in current_module.names: + model_cls_is_accessible = False + for scope in ctx.api.scope.stack: + if isinstance(scope, (MypyFile, TypeInfo)): + model_class_candidate = scope.names.get(model_string) + model_cls_is_accessible = (model_class_candidate is not None + and isinstance(model_class_candidate.node, TypeInfo) + and model_class_candidate.node.has_base(fullnames.MODEL_CLASS_FULLNAME)) + if model_cls_is_accessible: + break + # TODO: FuncItem + + if not model_cls_is_accessible: ctx.api.fail(f'No model {model_string!r} defined in the current module', ctx.context) return None + return outer_model_info.module_name + '.' + model_string app_label, model_name = model_string.split('.') @@ -88,12 +101,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan default_return_type = ctx.default_return_type assert isinstance(default_return_type, Instance) - # bail out if we're inside migration, not supported yet - active_class = ctx.api.scope.active_class() - if active_class is not None: - if active_class.has_base(fullnames.MIGRATION_CLASS_FULLNAME): - return ctx.default_return_type - if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): return fill_descriptor_types_for_related_field(ctx, django_context) diff --git a/test-data/typecheck/fields/test_generic_foreign_key.yml b/test-data/typecheck/fields/test_generic_foreign_key.yml new file mode 100644 index 0000000..a05f5e3 --- /dev/null +++ b/test-data/typecheck/fields/test_generic_foreign_key.yml @@ -0,0 +1,22 @@ +- case: generic_foreign_key_could_point_to_any_model_and_is_always_optional + main: | + from myapp.models import Tag, User + myuser = User() + Tag(content_object=None) + Tag(content_object=myuser) + Tag.objects.create(content_object=None) + Tag.objects.create(content_object=myuser) + reveal_type(Tag().content_object) # N: Revealed type is 'Union[Any, None]' + installed_apps: + - django.contrib.contenttypes + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from django.contrib.contenttypes import fields + class User(models.Model): + pass + class Tag(models.Model): + content_object = fields.GenericForeignKey() \ No newline at end of file diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index da574ae..e22a60c 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -381,10 +381,10 @@ - path: myapp/models.py content: | from django.db import models - class Book(models.Model): - publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE) class Publisher(models.Model): pass + class Book(models.Model): + publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE) - case: test_foreign_key_field_without_backwards_relation main: | diff --git a/test-data/typecheck/models/test_create.yml b/test-data/typecheck/models/test_create.yml index cf236be..5b24716 100644 --- a/test-data/typecheck/models/test_create.yml +++ b/test-data/typecheck/models/test_create.yml @@ -1,7 +1,7 @@ - case: default_manager_create_is_typechecked main: | from myapp.models import User - User.objects.create(name='Max', age=10) + User.objects.create(pk=1, name='Max', age=10) User.objects.create(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]") installed_apps: - myapp