add GenericForeignKey support, remove some false-positives

This commit is contained in:
Maxim Kurnikov
2019-07-18 18:31:37 +03:00
parent bfa77efef5
commit f2e79d3bfb
15 changed files with 111 additions and 62 deletions

View File

@@ -8,8 +8,8 @@ from django.db.models.base import Model
from .config import AppConfig
class Apps:
all_models: 'Dict[str, OrderedDict[str, Type[Model]]]' = ...
app_configs: 'OrderedDict[str, AppConfig]' = ...
all_models: "Dict[str, OrderedDict[str, Type[Model]]]" = ...
app_configs: "OrderedDict[str, AppConfig]" = ...
stored_app_configs: List[Any] = ...
apps_ready: bool = ...
ready_event: threading.Event = ...

View File

@@ -7,6 +7,7 @@ from django.db.models.fields.related import ForeignObject
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.expressions import Combinable
from django.db.models.fields import Field, PositiveIntegerField
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query import QuerySet
@@ -14,6 +15,10 @@ from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.where import WhereNode
class GenericForeignKey(FieldCacheMixin):
# django-stubs implementation only fields
_pyi_private_set_type: Union[Any, Combinable]
_pyi_private_get_type: Any
# attributes
auto_created: bool = ...
concrete: bool = ...
editable: bool = ...
@@ -44,10 +49,8 @@ class GenericForeignKey(FieldCacheMixin):
def get_prefetch_queryset(
self, instances: Union[List[Model], QuerySet], queryset: Optional[QuerySet] = ...
) -> Tuple[List[Model], Callable, Callable, bool, str, bool]: ...
def __get__(
self, instance: Optional[Model], cls: Type[Model] = ...
) -> Optional[Union[GenericForeignKey, Model]]: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Optional[Any]: ...
def __set__(self, instance: Model, value: Optional[Any]) -> None: ...
class GenericRel(ForeignObjectRel):
field: GenericRelation

View File

@@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError
from .base import BaseCommand as BaseCommand, CommandError as CommandError
def find_commands(management_dir: str) -> List[str]: ...
def load_command_class(app_name: str, name: str) -> BaseCommand: ...
def get_commands() -> Dict[str, str]: ...
def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> Optional[str]: ...
def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> str: ...
class ManagementUtility:
argv: List[str] = ...

View File

@@ -11,6 +11,7 @@ from .base import (
SerializationError as SerializationError,
DeserializationError as DeserializationError,
M2MDeserializationError as M2MDeserializationError,
DeserializedObject,
)
BUILTIN_SERIALIZERS: Any
@@ -27,10 +28,8 @@ def get_serializer(format: str) -> Union[Type[Serializer], BadSerializer]: ...
def get_serializer_formats() -> List[str]: ...
def get_public_serializer_formats() -> List[str]: ...
def get_deserializer(format: str) -> Union[Callable, Type[Deserializer]]: ...
def serialize(
format: str, queryset: Union[Iterator[Any], List[Model], QuerySet], **options: Any
) -> Optional[Union[bytes, str]]: ...
def deserialize(format: str, stream_or_string: Any, **options: Any) -> Union[Iterator[Any], Deserializer]: ...
def serialize(format: str, queryset: Iterable[Model], **options: Any) -> Optional[Union[bytes, str]]: ...
def deserialize(format: str, stream_or_string: Any, **options: Any) -> Iterator[DeserializedObject]: ...
def sort_dependencies(
app_list: Union[Iterable[Tuple[AppConfig, None]], Iterable[Tuple[str, Iterable[Type[Model]]]]]
) -> List[Type[Model]]: ...

View File

@@ -70,8 +70,8 @@ class Deserializer:
def __next__(self) -> None: ...
class DeserializedObject:
object: Model = ...
m2m_data: Dict[Any, Any] = ...
object: Any = ...
m2m_data: Dict[str, List[int]] = ...
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...

View File

@@ -6,7 +6,6 @@ from django.core.checks.messages import CheckMessage
from django.db.models.options import Options
class ModelBase(type): ...
_Self = TypeVar("_Self", bound="Model")

View File

@@ -1,7 +1,9 @@
import decimal
import uuid
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence, List
from django.core import checks
from django.db.models import Model
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
@@ -41,6 +43,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
null: bool = ...
editable: bool = ...
choices: Optional[_FieldChoices] = ...
db_column: Optional[str]
column: str
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -86,6 +90,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ...
def has_default(self) -> bool: ...
def get_default(self) -> Any: ...
def check(self, **kwargs: Any) -> List[checks.Error]: ...
class IntegerField(Field[_ST, _GT]):
_pyi_private_set_type: Union[float, int, str, Combinable]

View File

@@ -108,4 +108,6 @@ class Options:
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
def get_path_to_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_fields(self, include_parents: bool = ..., include_hidden: bool = ...) -> List[Union[Field, ForeignObjectRel]]: ...
def get_fields(
self, include_parents: bool = ..., include_hidden: bool = ...
) -> List[Union[Field, ForeignObjectRel]]: ...

View File

@@ -121,9 +121,11 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ...
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
# TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any, Any]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
# extra() return type won't be supported any time soon
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
@@ -132,7 +134,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[_T, _Row]: ...
) -> QuerySet[Any, Any]: ...
def reverse(self) -> QuerySet[_T, _Row]: ...
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...

View File

@@ -39,10 +39,10 @@ ungettext = ngettext
def pgettext(context: str, message: str) -> str: ...
def npgettext(context: str, singular: str, plural: str, number: int) -> str: ...
gettext_lazy: Any
gettext_lazy: Callable[[str], str]
ugettext_lazy: Any
pgettext_lazy: Any
ugettext_lazy: Callable[[str], str]
pgettext_lazy: Callable[[str], str]
def ngettext_lazy(singular: Any, plural: Any, number: Optional[Any] = ...): ...

View File

@@ -1,11 +1,13 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
from django.core.exceptions import FieldError, FieldDoesNotExist
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.types import Instance, Type as MypyType
@@ -13,9 +15,6 @@ from pytest_mypy.utils import temp_environ
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field
from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel
from django.db.models.sql.query import Query
from mypy_django_plugin.lib import helpers
if TYPE_CHECKING:
@@ -184,14 +183,16 @@ class DjangoContext:
raise ValueError('No primary key defined')
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {}
if method == '__init__':
# add pk
primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
expected_types['pk'] = field_set_type
for field in self.get_model_fields(model_cls):
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
field_name = field.attname
field_set_type = self.fields_context.get_field_set_type(api, field, method)
expected_types[field_name] = field_set_type
@@ -207,4 +208,13 @@ class DjangoContext:
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
Instance(related_model_info, []))
expected_types[field_name] = model_set_type
elif isinstance(field, GenericForeignKey):
# it's generic, so cannot set specific model
field_name = field.name
gfk_info = helpers.lookup_class_typeinfo(api, field.__class__)
gfk_set_type = helpers.get_private_descriptor_type(gfk_info, '_pyi_private_set_type',
is_nullable=True)
expected_types[field_name] = gfk_set_type
return expected_types

View File

@@ -1,6 +1,6 @@
from typing import Optional, Tuple, cast
from mypy.nodes import TypeInfo
from mypy.nodes import MypyFile, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, TypeOfAny
@@ -14,11 +14,13 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
assert isinstance(to_arg_type.ret_type, Instance)
return to_arg_type.ret_type.type.fullname()
outer_model_info = ctx.api.tscope.classes[-1]
outer_model_info = ctx.api.scope.active_class()
if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
# not inside models.Model class
return None
assert isinstance(outer_model_info, TypeInfo)
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context)
if model_string is None:
# unresolvable
@@ -28,10 +30,21 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
return outer_model_info.fullname()
if '.' not in model_string:
# same file class
current_module = ctx.api.tree
if model_string not in current_module.names:
model_cls_is_accessible = False
for scope in ctx.api.scope.stack:
if isinstance(scope, (MypyFile, TypeInfo)):
model_class_candidate = scope.names.get(model_string)
model_cls_is_accessible = (model_class_candidate is not None
and isinstance(model_class_candidate.node, TypeInfo)
and model_class_candidate.node.has_base(fullnames.MODEL_CLASS_FULLNAME))
if model_cls_is_accessible:
break
# TODO: FuncItem
if not model_cls_is_accessible:
ctx.api.fail(f'No model {model_string!r} defined in the current module', ctx.context)
return None
return outer_model_info.module_name + '.' + model_string
app_label, model_name = model_string.split('.')
@@ -88,12 +101,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance)
# bail out if we're inside migration, not supported yet
active_class = ctx.api.scope.active_class()
if active_class is not None:
if active_class.has_base(fullnames.MIGRATION_CLASS_FULLNAME):
return ctx.default_return_type
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx, django_context)

View File

@@ -0,0 +1,22 @@
- case: generic_foreign_key_could_point_to_any_model_and_is_always_optional
main: |
from myapp.models import Tag, User
myuser = User()
Tag(content_object=None)
Tag(content_object=myuser)
Tag.objects.create(content_object=None)
Tag.objects.create(content_object=myuser)
reveal_type(Tag().content_object) # N: Revealed type is 'Union[Any, None]'
installed_apps:
- django.contrib.contenttypes
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
from django.contrib.contenttypes import fields
class User(models.Model):
pass
class Tag(models.Model):
content_object = fields.GenericForeignKey()

View File

@@ -381,10 +381,10 @@
- path: myapp/models.py
content: |
from django.db import models
class Book(models.Model):
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
- case: test_foreign_key_field_without_backwards_relation
main: |

View File

@@ -1,7 +1,7 @@
- case: default_manager_create_is_typechecked
main: |
from myapp.models import User
User.objects.create(name='Max', age=10)
User.objects.create(pk=1, name='Max', age=10)
User.objects.create(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]")
installed_apps:
- myapp