mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 13:04:47 +08:00
add support for typechecking of filter/get/exclude arguments (#183)
* add support for typechecking of filter/get/exclude arguments * linting
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING, Dict, Iterator, Optional, Set, Tuple, Type, Union,
|
||||
TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union,
|
||||
)
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
@@ -11,14 +11,16 @@ from django.db.models.base import Model
|
||||
from django.db.models.fields import AutoField, CharField, Field
|
||||
from django.db.models.fields.related import ForeignKey, RelatedField
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||
from django.db.models.lookups import Exact
|
||||
from django.db.models.sql.query import Query
|
||||
from django.utils.functional import cached_property
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import TypeOfAny, UnionType
|
||||
|
||||
from mypy_django_plugin.lib import helpers
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
try:
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
@@ -153,33 +155,87 @@ class DjangoFieldsContext:
|
||||
return related_model_cls
|
||||
|
||||
|
||||
class LookupsAreUnsupported(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DjangoLookupsContext:
|
||||
def __init__(self, django_context: 'DjangoContext'):
|
||||
self.django_context = django_context
|
||||
|
||||
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
field = currently_observed_model._meta.get_field(field_part)
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
return field
|
||||
|
||||
def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if lookup_parts:
|
||||
raise FieldError('Lookups not supported yet')
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
currently_observed_model = model_cls
|
||||
current_field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
return self.django_context.get_primary_key_field(currently_observed_model)
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
current_field = currently_observed_model._meta.get_field(field_part)
|
||||
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
|
||||
continue
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
query = Query(model_cls)
|
||||
try:
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if is_expression:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
|
||||
if isinstance(current_field, ForeignObjectRel):
|
||||
current_field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
field = self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
# if it is None, solve_lookup_type() will fail earlier
|
||||
assert current_field is not None
|
||||
return current_field
|
||||
lookup_cls = None
|
||||
if lookup_parts:
|
||||
lookup = lookup_parts[-1]
|
||||
lookup_cls = field.get_lookup(lookup)
|
||||
if lookup_cls is None:
|
||||
# unknown lookup
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||
return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
||||
|
||||
assert lookup_cls is not None
|
||||
|
||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
||||
if lookup_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
for lookup_base in helpers.iter_bases(lookup_info):
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
|
||||
class DjangoContext:
|
||||
@@ -228,6 +284,27 @@ class DjangoContext:
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
yield field
|
||||
|
||||
def get_field_lookup_exact_type(self, api: TypeChecker, field: Field) -> MypyType:
|
||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||
related_model_cls = field.related_model
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init')
|
||||
|
||||
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if rel_model_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []),
|
||||
primary_key_type])
|
||||
return helpers.make_optional(model_and_primary_key_type)
|
||||
# return helpers.make_optional(Instance(rel_model_info, []))
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
|
||||
is_nullable=field.null)
|
||||
|
||||
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, cast,
|
||||
)
|
||||
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.fields.related import RelatedField
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||
from mypy import checker
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.mro import calculate_mro
|
||||
@@ -115,29 +120,50 @@ def parse_bool(expr: Expression) -> Optional[bool]:
|
||||
return None
|
||||
|
||||
|
||||
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
|
||||
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
|
||||
for base_fullname in bases:
|
||||
if info.has_base(base_fullname):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
|
||||
for base in info.bases:
|
||||
yield base
|
||||
yield from iter_bases(base.type)
|
||||
|
||||
|
||||
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
|
||||
""" Return declared type of type_info's private_field_name (used for private Field attributes)"""
|
||||
sym = type_info.get(private_field_name)
|
||||
if sym is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
node = sym.node
|
||||
if isinstance(node, Var):
|
||||
descriptor_type = node.type
|
||||
if descriptor_type is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if is_nullable:
|
||||
descriptor_type = make_optional(descriptor_type)
|
||||
return descriptor_type
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
|
||||
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
|
||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||
lookup_type_class = field.related_model
|
||||
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
|
||||
if rel_model_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
return make_optional(Instance(rel_model_info, []))
|
||||
|
||||
field_info = lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
|
||||
is_nullable=field.null)
|
||||
|
||||
|
||||
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
|
||||
|
||||
@@ -209,6 +209,8 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
|
||||
return partial(init_create.typecheck_queryset_filter, django_context=self.django_context)
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
|
||||
@@ -6,7 +6,7 @@ from mypy.types import Instance
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import helpers
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
@@ -30,6 +30,12 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
return actual_types
|
||||
|
||||
|
||||
def check_types_compatible(ctx, *, expected_type, actual_type, error_message):
|
||||
ctx.api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
||||
@@ -42,11 +48,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
continue
|
||||
typechecker_api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
'got', 'expected')
|
||||
check_types_compatible(ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__))
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -73,3 +79,40 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
|
||||
|
||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
lookup_kwargs = ctx.arg_names[1]
|
||||
provided_lookup_types = ctx.arg_types[1]
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls_fullname = ctx.type.args[0].type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
# Combinables are not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
continue
|
||||
|
||||
lookup_type = django_context.lookups_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
return ctx.default_return_type
|
||||
|
||||
check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -10,7 +10,9 @@ from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.django.context import (
|
||||
DjangoContext, LookupsAreUnsupported,
|
||||
)
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
@@ -38,10 +40,12 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
*, method: str, lookup: str) -> Optional[MypyType]:
|
||||
try:
|
||||
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
|
||||
lookup_field = django_context.lookups_context.resolve_lookup_info_field(model_cls, lookup)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return None
|
||||
except LookupsAreUnsupported:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
|
||||
related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
|
||||
|
||||
Reference in New Issue
Block a user