Files
django-stubs/mypy_django_plugin/transformers/fields.py
2019-07-24 13:39:23 +03:00

106 lines
4.6 KiB
Python

from typing import Optional, Tuple, cast

from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField
from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
    """Return the runtime Django field created by the call currently being type-checked.

    The call expression in ``ctx.context`` is matched against the right-hand side of
    the top-level assignment statements of the enclosing class body. The assigned
    name is then looked up on the runtime model class resolved through
    ``django_context``.

    Returns ``None`` when:
    * the call is not inside a class that has ``models.Model`` as a base;
    * the call is not the direct right-hand side of a plain ``name = ...`` assignment
      in the class body;
    * ``django_context`` has no runtime model class for the enclosing class;
    * the runtime model has no field with the assigned name.
    """
    outer_model_info = ctx.api.scope.active_class()
    if not isinstance(outer_model_info, TypeInfo) or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
        return None

    field_name = None
    for stmt in outer_model_info.defn.defs.body:
        # identity match: the field call being checked is this statement's rvalue
        if isinstance(stmt, AssignmentStmt) and stmt.rvalue is ctx.context:
            lvalue = stmt.lvalues[0]
            # only a plain ``name = Field(...)`` target maps to a model field name;
            # tuple/attribute targets have no usable ``.name``
            if isinstance(lvalue, NameExpr):
                field_name = lvalue.name
            break
    if field_name is None:
        return None

    model_cls = django_context.get_model_class_by_fullname(outer_model_info.fullname())
    if model_cls is None:
        return None

    try:
        return model_cls._meta.get_field(field_name)
    except FieldDoesNotExist:
        # the loaded runtime model does not know this name (e.g. it differs from the
        # source being checked); treat it like any other "cannot resolve" case
        return None
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Return a related field's type with its ``Any`` descriptor args replaced by model types.

    The first type argument (the ``__set__`` type) is filled with the related
    model's concrete model. This is the model itself unless the related model is a
    proxy. The second type argument (the ``__get__`` type) is filled with the
    related model itself.

    Returns ``AnyType(from_error)`` when the runtime field cannot be resolved.
    Returns the un-specialised descriptor types when mypy has no ``TypeInfo`` for
    one of the model classes.
    """
    current_field = _get_current_field_from_assignment(ctx, django_context)
    if current_field is None:
        return AnyType(TypeOfAny.from_error)

    assert isinstance(current_field, RelatedField)

    related_model = current_field.related_model
    related_model_to_set = related_model
    if related_model._meta.proxy_for_model is not None:
        # Django's forward FK descriptor validates assignments against
        # ``_meta.concrete_model``. ``proxy_for_model`` would only reach the
        # intermediate proxy of a proxy-of-proxy chain.
        related_model_to_set = related_model._meta.concrete_model

    default_related_field_type = set_descriptor_types_for_field(ctx)

    related_model_info = helpers.lookup_class_typeinfo(ctx.api, related_model)
    related_model_to_set_info = helpers.lookup_class_typeinfo(ctx.api, related_model_to_set)
    if related_model_info is None or related_model_to_set_info is None:
        # mypy does not know the model class; keep the generic descriptor types
        # rather than building an Instance around a missing TypeInfo
        return default_related_field_type

    set_type, get_type = default_related_field_type.args[0], default_related_field_type.args[1]
    # replace Any with the referred-to model types
    args = [
        helpers.convert_any_to_type(set_type, Instance(related_model_to_set_info, [])),
        helpers.convert_any_to_type(get_type, Instance(related_model_info, [])),
    ]
    return helpers.reparametrize_instance(default_related_field_type, new_args=args)
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:
    """Return the ``(set_type, get_type)`` pair declared on a field class.

    Each type is read from a private attribute of ``field_info``
    (``_pyi_private_set_type`` / ``_pyi_private_get_type``) through
    ``helpers.get_private_descriptor_type``, with ``is_nullable`` forwarded.
    """
    def descriptor_type(attr_name: str) -> MypyType:
        return helpers.get_private_descriptor_type(field_info, attr_name, is_nullable=is_nullable)

    # tuple elements evaluate left to right: set type first, then get type
    return descriptor_type('_pyi_private_set_type'), descriptor_type('_pyi_private_get_type')
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
    """Reparametrize the field call's return type with its descriptor types.

    The call's ``null=`` argument, parsed by ``helpers.parse_bool``, decides
    nullability. The resulting ``[set_type, get_type]`` from
    ``get_field_descriptor_types`` becomes the instance's type arguments.
    """
    field_instance = cast(Instance, ctx.default_return_type)
    null_expr = helpers.get_call_argument_by_name(ctx, 'null')
    descriptor_types = get_field_descriptor_types(field_instance.type, helpers.parse_bool(null_expr))
    return helpers.reparametrize_instance(field_instance, list(descriptor_types))
def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Return an ``ArrayField`` type whose ``Any`` args are filled from ``base_field``.

    The second type argument of the ``base_field=`` argument's type (its
    ``__get__`` type) replaces ``Any`` in each of the array field's own
    descriptor type arguments.

    Falls back to the plain descriptor types when ``base_field`` is missing or
    its type is not an ``Instance`` with at least two type arguments.
    ``django_context`` is not used here.
    """
    default_return_type = set_descriptor_types_for_field(ctx)

    base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field')
    # a non-generic (or non-field) base_field has no [set, get] args to borrow from;
    # indexing args[1] on it would raise IndexError inside the plugin
    if not isinstance(base_field_arg_type, Instance) or len(base_field_arg_type.args) < 2:
        return default_return_type

    base_type = base_field_arg_type.args[1]  # extract __get__ type
    args = [helpers.convert_any_to_type(default_arg, base_type) for default_arg in default_return_type.args]
    return helpers.reparametrize_instance(default_return_type, args)
def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
    """Dispatch a field-constructor call inside a model body to the matching transformer.

    ``ctx.default_return_type`` is returned unchanged when it is not an
    ``Instance`` or when the call is not inside a ``models.Model`` subclass.
    Otherwise:
    * related fields go to ``fill_descriptor_types_for_related_field``;
    * ``ArrayField`` subclasses go to ``determine_type_of_array_field``;
    * every other field goes to ``set_descriptor_types_for_field``.
    """
    default_return_type = ctx.default_return_type
    if not isinstance(default_return_type, Instance):
        # nothing to reparametrize; do not crash the plugin on an unexpected type
        return default_return_type

    outer_model_info = ctx.api.scope.active_class()
    # narrow before calling has_base(): active_class() can be falsy outside a class body
    if not isinstance(outer_model_info, TypeInfo) or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
        # not inside models.Model class
        return default_return_type

    field_info = default_return_type.type
    if helpers.has_any_of_bases(field_info, fullnames.RELATED_FIELDS_CLASSES):
        return fill_descriptor_types_for_related_field(ctx, django_context)
    if field_info.has_base(fullnames.ARRAY_FIELD_FULLNAME):
        return determine_type_of_array_field(ctx, django_context)
    return set_descriptor_types_for_field(ctx)