Model.__init__ supporting same typing as assigment (#835)

* `Model.__init__` supporting same typing as assigment

* Update mypy_django_plugin/django/context.py
This commit is contained in:
Petter Friberg
2022-01-29 10:07:26 +01:00
committed by GitHub
parent c556668d7a
commit 8aae836a26
9 changed files with 177 additions and 18 deletions

View File

@@ -15,6 +15,7 @@ from django.db.models.lookups import Exact
from django.db.models.sql.query import Query
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.plugin import MethodContext
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
@@ -174,10 +175,29 @@ class DjangoContext:
field_set_type = self.get_field_set_type(api, primary_key_field, method=method)
expected_types["pk"] = field_set_type
def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]:
if info is None:
return None
field_node = info.names.get(field_name)
if field_node is None or not isinstance(field_node.type, Instance):
return None
elif not field_node.type.args:
# Field declares a set and a get type arg. Fallback to `None` when we can't find any args
return None
set_type = field_node.type.args[0]
return set_type
model_info = helpers.lookup_class_typeinfo(api, model_cls)
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
field_name = field.attname
field_set_type = self.get_field_set_type(api, field, method=method)
# Try to retrieve set type from a model's TypeInfo object and fallback to retrieving it manually
# from django-stubs own declaration. This is to align with the setter types declared for
# assignment.
field_set_type = get_field_set_type_from_model_type_info(
model_info, field_name
) or self.get_field_set_type(api, field, method=method)
expected_types[field_name] = field_set_type
if isinstance(field, ForeignKey):

View File

@@ -36,6 +36,7 @@ MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration"
OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options"
HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest"
COMBINABLE_EXPRESSION_FULLNAME = "django.db.models.expressions.Combinable"
F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"
ANY_ATTR_ALLOWED_CLASS_FULLNAME = "django_stubs_ext.AnyAttrAllowed"

View File

@@ -6,7 +6,7 @@ from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypeOfAny, UnionType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
@@ -125,6 +125,11 @@ def set_descriptor_types_for_field(
null_expr = helpers.get_call_argument_by_name(ctx, "null")
if null_expr is not None:
is_nullable = helpers.parse_bool(null_expr) or False
# Allow setting field value to `None` when a field is primary key and has a default that can produce a value
default_expr = helpers.get_call_argument_by_name(ctx, "default")
primary_key_expr = helpers.get_call_argument_by_name(ctx, "primary_key")
if default_expr is not None and primary_key_expr is not None:
is_set_nullable = helpers.parse_bool(primary_key_expr) or False
set_type, get_type = get_field_descriptor_types(
default_return_type.type,
@@ -141,10 +146,46 @@ def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoCo
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
base_type = base_field_arg_type.args[1] # extract __get__ type
def drop_combinable(_type: MypyType) -> Optional[MypyType]:
if isinstance(_type, Instance) and _type.type.has_base(fullnames.COMBINABLE_EXPRESSION_FULLNAME):
return None
elif isinstance(_type, UnionType):
items_without_combinable = []
for item in _type.items:
reduced = drop_combinable(item)
if reduced is not None:
items_without_combinable.append(reduced)
if len(items_without_combinable) > 1:
return UnionType(
items_without_combinable,
line=_type.line,
column=_type.column,
is_evaluated=_type.is_evaluated,
uses_pep604_syntax=_type.uses_pep604_syntax,
)
elif len(items_without_combinable) == 1:
return items_without_combinable[0]
else:
return None
return _type
# Both base_field and return type should derive from Field and thus expect 2 arguments
assert len(base_field_arg_type.args) == len(default_return_type.args) == 2
args = []
for default_arg in default_return_type.args:
args.append(helpers.convert_any_to_type(default_arg, base_type))
for new_type, default_arg in zip(base_field_arg_type.args, default_return_type.args):
# Drop any base_field Combinable type
reduced = drop_combinable(new_type)
if reduced is None:
ctx.api.fail(
f"Can't have ArrayField expecting {fullnames.COMBINABLE_EXPRESSION_FULLNAME!r} as data type",
ctx.context,
)
else:
new_type = reduced
args.append(helpers.convert_any_to_type(default_arg, new_type))
return helpers.reparametrize_instance(default_return_type, args)

View File

@@ -29,7 +29,7 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
if lookup_kwarg is None:
continue
if isinstance(provided_type, Instance) and provided_type.type.has_base(
"django.db.models.expressions.Combinable"
fullnames.COMBINABLE_EXPRESSION_FULLNAME
):
provided_type = resolve_combinable_type(provided_type, django_context)