add GenericForeignKey support, remove some false-positives

This commit is contained in:
Maxim Kurnikov
2019-07-18 18:31:37 +03:00
parent bfa77efef5
commit f2e79d3bfb
15 changed files with 111 additions and 62 deletions

View File

@@ -8,8 +8,8 @@ from django.db.models.base import Model
from .config import AppConfig from .config import AppConfig
class Apps: class Apps:
all_models: 'Dict[str, OrderedDict[str, Type[Model]]]' = ... all_models: "Dict[str, OrderedDict[str, Type[Model]]]" = ...
app_configs: 'OrderedDict[str, AppConfig]' = ... app_configs: "OrderedDict[str, AppConfig]" = ...
stored_app_configs: List[Any] = ... stored_app_configs: List[Any] = ...
apps_ready: bool = ... apps_ready: bool = ...
ready_event: threading.Event = ... ready_event: threading.Event = ...

View File

@@ -7,6 +7,7 @@ from django.db.models.fields.related import ForeignObject
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor
from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.expressions import Combinable
from django.db.models.fields import Field, PositiveIntegerField from django.db.models.fields import Field, PositiveIntegerField
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
@@ -14,6 +15,10 @@ from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
class GenericForeignKey(FieldCacheMixin): class GenericForeignKey(FieldCacheMixin):
# django-stubs implementation only fields
_pyi_private_set_type: Union[Any, Combinable]
_pyi_private_get_type: Any
# attributes
auto_created: bool = ... auto_created: bool = ...
concrete: bool = ... concrete: bool = ...
editable: bool = ... editable: bool = ...
@@ -44,10 +49,8 @@ class GenericForeignKey(FieldCacheMixin):
def get_prefetch_queryset( def get_prefetch_queryset(
self, instances: Union[List[Model], QuerySet], queryset: Optional[QuerySet] = ... self, instances: Union[List[Model], QuerySet], queryset: Optional[QuerySet] = ...
) -> Tuple[List[Model], Callable, Callable, bool, str, bool]: ... ) -> Tuple[List[Model], Callable, Callable, bool, str, bool]: ...
def __get__( def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Optional[Any]: ...
self, instance: Optional[Model], cls: Type[Model] = ... def __set__(self, instance: Model, value: Optional[Any]) -> None: ...
) -> Optional[Union[GenericForeignKey, Model]]: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
class GenericRel(ForeignObjectRel): class GenericRel(ForeignObjectRel):
field: GenericRelation field: GenericRelation

View File

@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError from .base import BaseCommand as BaseCommand, CommandError as CommandError
def find_commands(management_dir: str) -> List[str]: ... def find_commands(management_dir: str) -> List[str]: ...
def load_command_class(app_name: str, name: str) -> BaseCommand: ... def load_command_class(app_name: str, name: str) -> BaseCommand: ...
def get_commands() -> Dict[str, str]: ... def get_commands() -> Dict[str, str]: ...
def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> Optional[str]: ... def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> str: ...
class ManagementUtility: class ManagementUtility:
argv: List[str] = ... argv: List[str] = ...

View File

@@ -11,6 +11,7 @@ from .base import (
SerializationError as SerializationError, SerializationError as SerializationError,
DeserializationError as DeserializationError, DeserializationError as DeserializationError,
M2MDeserializationError as M2MDeserializationError, M2MDeserializationError as M2MDeserializationError,
DeserializedObject,
) )
BUILTIN_SERIALIZERS: Any BUILTIN_SERIALIZERS: Any
@@ -27,10 +28,8 @@ def get_serializer(format: str) -> Union[Type[Serializer], BadSerializer]: ...
def get_serializer_formats() -> List[str]: ... def get_serializer_formats() -> List[str]: ...
def get_public_serializer_formats() -> List[str]: ... def get_public_serializer_formats() -> List[str]: ...
def get_deserializer(format: str) -> Union[Callable, Type[Deserializer]]: ... def get_deserializer(format: str) -> Union[Callable, Type[Deserializer]]: ...
def serialize( def serialize(format: str, queryset: Iterable[Model], **options: Any) -> Optional[Union[bytes, str]]: ...
format: str, queryset: Union[Iterator[Any], List[Model], QuerySet], **options: Any def deserialize(format: str, stream_or_string: Any, **options: Any) -> Iterator[DeserializedObject]: ...
) -> Optional[Union[bytes, str]]: ...
def deserialize(format: str, stream_or_string: Any, **options: Any) -> Union[Iterator[Any], Deserializer]: ...
def sort_dependencies( def sort_dependencies(
app_list: Union[Iterable[Tuple[AppConfig, None]], Iterable[Tuple[str, Iterable[Type[Model]]]]] app_list: Union[Iterable[Tuple[AppConfig, None]], Iterable[Tuple[str, Iterable[Type[Model]]]]]
) -> List[Type[Model]]: ... ) -> List[Type[Model]]: ...

View File

@@ -70,8 +70,8 @@ class Deserializer:
def __next__(self) -> None: ... def __next__(self) -> None: ...
class DeserializedObject: class DeserializedObject:
object: Model = ... object: Any = ...
m2m_data: Dict[Any, Any] = ... m2m_data: Dict[str, List[int]] = ...
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ... def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ... def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...

View File

@@ -6,7 +6,6 @@ from django.core.checks.messages import CheckMessage
from django.db.models.options import Options from django.db.models.options import Options
class ModelBase(type): ... class ModelBase(type): ...
_Self = TypeVar("_Self", bound="Model") _Self = TypeVar("_Self", bound="Model")

View File

@@ -1,7 +1,9 @@
import decimal import decimal
import uuid import uuid
from datetime import date, datetime, time, timedelta from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence, List
from django.core import checks
from django.db.models import Model from django.db.models import Model
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
@@ -41,6 +43,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
null: bool = ... null: bool = ...
editable: bool = ... editable: bool = ...
choices: Optional[_FieldChoices] = ... choices: Optional[_FieldChoices] = ...
db_column: Optional[str]
column: str
def __init__( def __init__(
self, self,
verbose_name: Optional[Union[str, bytes]] = ..., verbose_name: Optional[Union[str, bytes]] = ...,
@@ -86,6 +90,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ... ) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ...
def has_default(self) -> bool: ... def has_default(self) -> bool: ...
def get_default(self) -> Any: ... def get_default(self) -> Any: ...
def check(self, **kwargs: Any) -> List[checks.Error]: ...
class IntegerField(Field[_ST, _GT]): class IntegerField(Field[_ST, _GT]):
_pyi_private_set_type: Union[float, int, str, Combinable] _pyi_private_set_type: Union[float, int, str, Combinable]

View File

@@ -108,4 +108,6 @@ class Options:
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ... def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
def get_path_to_parent(self, parent: Type[Model]) -> List[PathInfo]: ... def get_path_to_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ... def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_fields(self, include_parents: bool = ..., include_hidden: bool = ...) -> List[Union[Field, ForeignObjectRel]]: ... def get_fields(
self, include_parents: bool = ..., include_hidden: bool = ...
) -> List[Union[Field, ForeignObjectRel]]: ...

View File

@@ -121,9 +121,11 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ... def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ...
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ... def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ... def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ... # TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any, Any]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ... def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ... def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
# extra() return type won't be supported any time soon
def extra( def extra(
self, self,
select: Optional[Dict[str, Any]] = ..., select: Optional[Dict[str, Any]] = ...,
@@ -132,7 +134,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
tables: Optional[List[str]] = ..., tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ..., order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ..., select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[_T, _Row]: ... ) -> QuerySet[Any, Any]: ...
def reverse(self) -> QuerySet[_T, _Row]: ... def reverse(self) -> QuerySet[_T, _Row]: ...
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ... def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ... def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...

View File

@@ -39,10 +39,10 @@ ungettext = ngettext
def pgettext(context: str, message: str) -> str: ... def pgettext(context: str, message: str) -> str: ...
def npgettext(context: str, singular: str, plural: str, number: int) -> str: ... def npgettext(context: str, singular: str, plural: str, number: int) -> str: ...
gettext_lazy: Any gettext_lazy: Callable[[str], str]
ugettext_lazy: Any ugettext_lazy: Callable[[str], str]
pgettext_lazy: Any pgettext_lazy: Callable[[str], str]
def ngettext_lazy(singular: Any, plural: Any, number: Optional[Any] = ...): ... def ngettext_lazy(singular: Any, plural: Any, number: Optional[Any] = ...): ...

View File

@@ -1,11 +1,13 @@
import os import os
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
from django.core.exceptions import FieldError, FieldDoesNotExist from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from django.utils.functional import cached_property from django.utils.functional import cached_property
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.types import Instance, Type as MypyType from mypy.types import Instance, Type as MypyType
@@ -13,9 +15,6 @@ from pytest_mypy.utils import temp_environ
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field from django.db.models.fields import CharField, Field
from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel
from django.db.models.sql.query import Query
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -99,7 +98,7 @@ class DjangoFieldsContext:
return Instance(model_info, []) return Instance(model_info, [])
else: else:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable) is_nullable=is_nullable)
class DjangoLookupsContext: class DjangoLookupsContext:
@@ -184,27 +183,38 @@ class DjangoContext:
raise ValueError('No primary key defined') raise ValueError('No primary key defined')
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]: def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {} expected_types = {}
if method == '__init__': # add pk
# add pk primary_key_field = self.get_primary_key_field(model_cls)
primary_key_field = self.get_primary_key_field(model_cls) field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method) expected_types['pk'] = field_set_type
expected_types['pk'] = field_set_type
for field in self.get_model_fields(model_cls): for field in model_cls._meta.get_fields():
field_name = field.attname if isinstance(field, Field):
field_set_type = self.fields_context.get_field_set_type(api, field, method) field_name = field.attname
expected_types[field_name] = field_set_type field_set_type = self.fields_context.get_field_set_type(api, field, method)
expected_types[field_name] = field_set_type
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
field_name = field.name
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model)
is_nullable = self.fields_context.get_field_nullability(field, method)
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
'_pyi_private_set_type',
is_nullable=is_nullable)
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
Instance(related_model_info, []))
expected_types[field_name] = model_set_type
elif isinstance(field, GenericForeignKey):
# it's generic, so cannot set specific model
field_name = field.name field_name = field.name
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) gfk_info = helpers.lookup_class_typeinfo(api, field.__class__)
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model) gfk_set_type = helpers.get_private_descriptor_type(gfk_info, '_pyi_private_set_type',
is_nullable = self.fields_context.get_field_nullability(field, method) is_nullable=True)
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, expected_types[field_name] = gfk_set_type
'_pyi_private_set_type',
is_nullable=is_nullable)
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
Instance(related_model_info, []))
expected_types[field_name] = model_set_type
return expected_types return expected_types

View File

@@ -1,6 +1,6 @@
from typing import Optional, Tuple, cast from typing import Optional, Tuple, cast
from mypy.nodes import TypeInfo from mypy.nodes import MypyFile, TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, TypeOfAny from mypy.types import AnyType, CallableType, Instance, Type as MypyType, TypeOfAny
@@ -14,11 +14,13 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
assert isinstance(to_arg_type.ret_type, Instance) assert isinstance(to_arg_type.ret_type, Instance)
return to_arg_type.ret_type.type.fullname() return to_arg_type.ret_type.type.fullname()
outer_model_info = ctx.api.tscope.classes[-1] outer_model_info = ctx.api.scope.active_class()
if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
# not inside models.Model class
return None
assert isinstance(outer_model_info, TypeInfo) assert isinstance(outer_model_info, TypeInfo)
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to') to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context) model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context)
if model_string is None: if model_string is None:
# unresolvable # unresolvable
@@ -28,10 +30,21 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
return outer_model_info.fullname() return outer_model_info.fullname()
if '.' not in model_string: if '.' not in model_string:
# same file class # same file class
current_module = ctx.api.tree model_cls_is_accessible = False
if model_string not in current_module.names: for scope in ctx.api.scope.stack:
if isinstance(scope, (MypyFile, TypeInfo)):
model_class_candidate = scope.names.get(model_string)
model_cls_is_accessible = (model_class_candidate is not None
and isinstance(model_class_candidate.node, TypeInfo)
and model_class_candidate.node.has_base(fullnames.MODEL_CLASS_FULLNAME))
if model_cls_is_accessible:
break
# TODO: FuncItem
if not model_cls_is_accessible:
ctx.api.fail(f'No model {model_string!r} defined in the current module', ctx.context) ctx.api.fail(f'No model {model_string!r} defined in the current module', ctx.context)
return None return None
return outer_model_info.module_name + '.' + model_string return outer_model_info.module_name + '.' + model_string
app_label, model_name = model_string.split('.') app_label, model_name = model_string.split('.')
@@ -88,12 +101,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
default_return_type = ctx.default_return_type default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance) assert isinstance(default_return_type, Instance)
# bail out if we're inside migration, not supported yet
active_class = ctx.api.scope.active_class()
if active_class is not None:
if active_class.has_base(fullnames.MIGRATION_CLASS_FULLNAME):
return ctx.default_return_type
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx, django_context) return fill_descriptor_types_for_related_field(ctx, django_context)

View File

@@ -0,0 +1,22 @@
- case: generic_foreign_key_could_point_to_any_model_and_is_always_optional
main: |
from myapp.models import Tag, User
myuser = User()
Tag(content_object=None)
Tag(content_object=myuser)
Tag.objects.create(content_object=None)
Tag.objects.create(content_object=myuser)
reveal_type(Tag().content_object) # N: Revealed type is 'Union[Any, None]'
installed_apps:
- django.contrib.contenttypes
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
from django.contrib.contenttypes import fields
class User(models.Model):
pass
class Tag(models.Model):
content_object = fields.GenericForeignKey()

View File

@@ -381,10 +381,10 @@
- path: myapp/models.py - path: myapp/models.py
content: | content: |
from django.db import models from django.db import models
class Book(models.Model):
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
class Publisher(models.Model): class Publisher(models.Model):
pass pass
class Book(models.Model):
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
- case: test_foreign_key_field_without_backwards_relation - case: test_foreign_key_field_without_backwards_relation
main: | main: |

View File

@@ -1,7 +1,7 @@
- case: default_manager_create_is_typechecked - case: default_manager_create_is_typechecked
main: | main: |
from myapp.models import User from myapp.models import User
User.objects.create(name='Max', age=10) User.objects.create(pk=1, name='Max', age=10)
User.objects.create(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]") User.objects.create(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]")
installed_apps: installed_apps:
- myapp - myapp