add support for typechecking of filter/get/exclude arguments (#183)

* add support for typechecking of filter/get/exclude arguments

* linting
This commit is contained in:
Maxim Kurnikov
2019-09-30 03:05:40 +03:00
committed by GitHub
parent 4d4b0003bd
commit 02bdf5be95
10 changed files with 451 additions and 48 deletions

View File

@@ -2,7 +2,7 @@ import os
from collections import defaultdict
from contextlib import contextmanager
from typing import (
TYPE_CHECKING, Dict, Iterator, Optional, Set, Tuple, Type, Union,
TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union,
)
from django.core.exceptions import FieldError
@@ -11,14 +11,16 @@ from django.db.models.base import Model
from django.db.models.fields import AutoField, CharField, Field
from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.lookups import Exact
from django.db.models.sql.query import Query
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.plugin import MethodContext
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypeOfAny, UnionType
from mypy_django_plugin.lib import helpers
from mypy_django_plugin.lib import fullnames, helpers
try:
from django.contrib.postgres.fields import ArrayField
@@ -153,33 +155,87 @@ class DjangoFieldsContext:
return related_model_cls
class LookupsAreUnsupported(Exception):
pass
class DjangoLookupsContext:
def __init__(self, django_context: 'DjangoContext'):
self.django_context = django_context
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Field:
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
currently_observed_model = model_cls
field = None
for field_part in field_parts:
if field_part == 'pk':
field = self.django_context.get_primary_key_field(currently_observed_model)
continue
field = currently_observed_model._meta.get_field(field_part)
if isinstance(field, RelatedField):
currently_observed_model = field.related_model
model_name = currently_observed_model._meta.model_name
if (model_name is not None
and field_part == (model_name + '_id')):
field = self.django_context.get_primary_key_field(currently_observed_model)
if isinstance(field, ForeignObjectRel):
currently_observed_model = field.related_model
assert field is not None
return field
def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field:
query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts:
raise FieldError('Lookups not supported yet')
raise LookupsAreUnsupported()
currently_observed_model = model_cls
current_field = None
for field_part in field_parts:
if field_part == 'pk':
return self.django_context.get_primary_key_field(currently_observed_model)
return self._resolve_field_from_parts(field_parts, model_cls)
current_field = currently_observed_model._meta.get_field(field_part)
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
continue
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
query = Query(model_cls)
try:
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if is_expression:
return AnyType(TypeOfAny.explicit)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return AnyType(TypeOfAny.from_error)
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
if isinstance(current_field, ForeignObjectRel):
current_field = self.django_context.get_primary_key_field(currently_observed_model)
field = self._resolve_field_from_parts(field_parts, model_cls)
# if it is None, solve_lookup_type() will fail earlier
assert current_field is not None
return current_field
lookup_cls = None
if lookup_parts:
lookup = lookup_parts[-1]
lookup_cls = field.get_lookup(lookup)
if lookup_cls is None:
# unknown lookup
return AnyType(TypeOfAny.explicit)
if lookup_cls is None or isinstance(lookup_cls, Exact):
return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
assert lookup_cls is not None
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
if lookup_info is None:
return AnyType(TypeOfAny.explicit)
for lookup_base in helpers.iter_bases(lookup_info):
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
lookup_type: MypyType = lookup_base.args[0]
# if it's Field, consider lookup_type a __get__ of current field
if (isinstance(lookup_type, Instance)
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
if field_info is None:
return AnyType(TypeOfAny.explicit)
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=field.null)
return lookup_type
return AnyType(TypeOfAny.explicit)
class DjangoContext:
@@ -228,6 +284,27 @@ class DjangoContext:
if isinstance(field, ForeignObjectRel):
yield field
def get_field_lookup_exact_type(self, api: TypeChecker, field: Field) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
related_model_cls = field.related_model
primary_key_field = self.get_primary_key_field(related_model_cls)
primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init')
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if rel_model_info is None:
return AnyType(TypeOfAny.explicit)
model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []),
primary_key_type])
return helpers.make_optional(model_and_primary_key_type)
# return helpers.make_optional(Instance(rel_model_info, []))
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
if field_info is None:
return AnyType(TypeOfAny.explicit)
return helpers.get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
is_nullable=field.null)
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):

View File

@@ -1,6 +1,11 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, cast,
)
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
@@ -115,29 +120,50 @@ def parse_bool(expr: Expression) -> Optional[bool]:
return None
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
for base in info.bases:
yield base
yield from iter_bases(base.type)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
""" Return declared type of type_info's private_field_name (used for private Field attributes)"""
sym = type_info.get(private_field_name)
if sym is None:
return AnyType(TypeOfAny.unannotated)
return AnyType(TypeOfAny.explicit)
node = sym.node
if isinstance(node, Var):
descriptor_type = node.type
if descriptor_type is None:
return AnyType(TypeOfAny.unannotated)
return AnyType(TypeOfAny.explicit)
if is_nullable:
descriptor_type = make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
return AnyType(TypeOfAny.explicit)
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
lookup_type_class = field.related_model
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
if rel_model_info is None:
return AnyType(TypeOfAny.from_error)
return make_optional(Instance(rel_model_info, []))
field_info = lookup_class_typeinfo(api, field.__class__)
if field_info is None:
return AnyType(TypeOfAny.explicit)
return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
is_nullable=field.null)
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:

View File

@@ -209,6 +209,8 @@ class NewSemanalDjangoPlugin(Plugin):
manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
return partial(init_create.typecheck_queryset_filter, django_context=self.django_context)
return None
def get_base_class_hook(self, fullname: str

View File

@@ -6,7 +6,7 @@ from mypy.types import Instance
from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers
from mypy_django_plugin.lib import fullnames, helpers
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
@@ -30,6 +30,12 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
return actual_types
def check_types_compatible(ctx, *, expected_type, actual_type, error_message):
ctx.api.check_subtype(actual_type, expected_type,
ctx.context, error_message,
'got', 'expected')
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType:
typechecker_api = helpers.get_typechecker_api(ctx)
@@ -42,11 +48,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
model_cls.__name__),
ctx.context)
continue
typechecker_api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__),
'got', 'expected')
check_types_compatible(ctx,
expected_type=expected_types[actual_name],
actual_type=actual_type,
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__))
return ctx.default_return_type
@@ -73,3 +79,40 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan
return ctx.default_return_type
return typecheck_model_method(ctx, django_context, model_cls, 'create')
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
lookup_kwargs = ctx.arg_names[1]
provided_lookup_types = ctx.arg_types[1]
assert isinstance(ctx.type, Instance)
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
return ctx.default_return_type
model_cls_fullname = ctx.type.args[0].type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
if model_cls is None:
return ctx.default_return_type
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
if lookup_kwarg is None:
continue
# Combinables are not supported yet
if (isinstance(provided_type, Instance)
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
continue
lookup_type = django_context.lookups_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
# Managers as provided_type is not supported yet
if (isinstance(provided_type, Instance)
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
fullnames.QUERYSET_CLASS_FULLNAME))):
return ctx.default_return_type
check_types_compatible(ctx,
expected_type=lookup_type,
actual_type=provided_type,
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
return ctx.default_return_type

View File

@@ -10,7 +10,9 @@ from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.django.context import (
DjangoContext, LookupsAreUnsupported,
)
from mypy_django_plugin.lib import fullnames, helpers
@@ -38,10 +40,12 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
*, method: str, lookup: str) -> Optional[MypyType]:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
lookup_field = django_context.lookups_context.resolve_lookup_info_field(model_cls, lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return None
except LookupsAreUnsupported:
return AnyType(TypeOfAny.explicit)
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)