preliminary support for strict_optional

This commit is contained in:
Maxim Kurnikov
2019-02-17 18:07:53 +03:00
parent 6763217a80
commit e9f9202ed1
23 changed files with 614 additions and 343 deletions

View File

@@ -5,10 +5,12 @@ from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
from mypy.types import AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
@@ -95,12 +97,13 @@ def parse_bool(expr: Expression) -> Optional[bool]:
return None
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def reparametrize_instance(instance: Instance, new_args: typing.List[Type]) -> Instance:
return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def fill_typevars_with_any(instance: Instance) -> Instance:
return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
@@ -117,7 +120,7 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
for typevar_arg in type_to_fill.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
return reparametrize_with(type_to_fill, typevar_values)
return Instance(type_to_fill.type, typevar_values)
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
@@ -189,28 +192,12 @@ def iter_over_assignments(
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
""" Extract __set__ value of a field. """
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
field_getter_type = extract_field_getter_type(tp)
if field_getter_type:
return field_getter_type
return tp.args[0]
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
@@ -218,9 +205,7 @@ def extract_field_getter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
return tp.args[1]
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
@@ -240,7 +225,10 @@ def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
"""
If field with primary_key=True is set on the model, extract its __set__ type.
"""
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
@@ -254,3 +242,30 @@ def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None
def make_optional(typ: Type):
return UnionType.make_simplified_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
# will reduce to Instance, if only one item
return UnionType.make_union(items)
def is_optional(typ: Type) -> bool:
if not isinstance(typ, UnionType):
return False
return any([isinstance(item, NoneTyp) for item in typ.items])
def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False

View File

@@ -1,19 +1,18 @@
import os
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Union, cast
from mypy.checker import TypeChecker
from mypy.nodes import MemberExpr, TypeInfo
from mypy.options import Options
from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins import init_create
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata
from mypy_django_plugin.transformers import fields, init_create
from mypy_django_plugin.transformers.migrations import determine_model_cls_from_string_for_migrations, \
get_string_value_from_expr
from mypy_django_plugin.transformers.models import process_model_class
from mypy_django_plugin.transformers.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata
def transform_model_class(ctx: ClassDefContext) -> None:
@@ -50,7 +49,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}:
ret.type.bases[i] = reparametrize_with(base, [Instance(outer_model_info, [])])
ret.type.bases[i] = Instance(base.type, [Instance(outer_model_info, [])])
return ret
return ret
@@ -84,6 +83,17 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, []))
def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]:
if isinstance(typ, Instance):
return typ.type
else:
# should be Union[TYPE, None]
typ = helpers.make_required(typ)
if isinstance(typ, Instance):
return typ.type
return None
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
@@ -94,14 +104,22 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
referred_to = sym.type.args[1]
if isinstance(referred_to, AnyType):
return AnyType(TypeOfAny.implementation_artifact)
model_type = _extract_referred_to_type_info(referred_to)
if model_type is None:
return AnyType(TypeOfAny.implementation_artifact)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False)
if is_nullable:
return helpers.make_optional(ctx.default_attr_type)
return ctx.default_attr_type
@@ -179,26 +197,19 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FIELD_FULLNAME):
return fields.adjust_return_type_of_field_instantiation
if fullname == 'django.contrib.auth.get_user_model':
return return_user_model_hook
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return determine_proper_manager_type
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'):
return init_create.redefine_and_typecheck_model_init

View File

@@ -1,74 +0,0 @@
from typing import cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, TupleExpr
from mypy.plugin import FunctionContext
from mypy.types import Instance, TupleType, Type
from mypy_django_plugin import helpers
from mypy_django_plugin.plugins.models import iter_over_assignments
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return ctx.default_return_type
get_method = base_field_arg_type.type.get_method('__get__')
if not get_method:
# not a method
return ctx.default_return_type
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[get_method.type.ret_type])
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return ctx.default_return_type
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return ctx.default_return_type
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable
return ctx.default_return_type

View File

@@ -1,62 +0,0 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import CallableType, Instance, Type
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
# shouldn't happen, invalid code
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
try:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value,
all_modules=api.modules)
except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname()
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return fill_typevars_with_any(ctx.default_return_type)
if referred_to_type is None:
# couldn't extract to= value
return fill_typevars_with_any(ctx.default_return_type)
return reparametrize_with(ctx.default_return_type, [referred_to_type])

View File

@@ -0,0 +1,199 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers
from mypy_django_plugin.transformers.models import iter_over_assignments
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
try:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value,
all_modules=api.modules)
except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname()
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_simplified_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return None
return referred_to_type
def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
referred_to_type = _extract_referred_to_type(ctx)
if referred_to_type is None:
return default_return_type
# replace Any with referred_to_type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, referred_to_type))
return helpers.reparametrize_instance(ctx.default_return_type, new_args=args)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
if is_nullable:
descriptor_type = helpers.make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
is_nullable=is_nullable)
get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
is_nullable=is_nullable)
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
base_type = base_field_arg_type.args[1] # extract __get__ type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, base_type))
return helpers.reparametrize_instance(default_return_type, args)
def transform_into_proper_return_type(ctx: FunctionContext) -> Type:
default_return_type = ctx.default_return_type
if not isinstance(default_return_type, Instance):
return default_return_type
if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME)):
return fill_descriptor_types_for_related_field(ctx)
if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx)
return set_descriptor_types_for_field(ctx)
def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type:
record_field_properties_into_outer_model_class(ctx)
return transform_into_proper_return_type(ctx)
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable

View File

@@ -3,10 +3,10 @@ from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
from mypy_django_plugin.helpers import extract_field_setter_type, extract_explicit_set_type_of_model_primary_key, get_fields_metadata
from mypy_django_plugin.transformers.fields import get_private_descriptor_type
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
@@ -112,40 +112,54 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
api = cast(TypeChecker, ctx.api)
primary_key_type = extract_primary_key_type_for_set(model)
expected_types: Dict[str, Type] = {}
primary_key_type = extract_explicit_set_type_of_model_primary_key(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
# extract all fields for all models in MRO
for name, sym in base.names.items():
# do not redefine special attrs
if name in {'_meta', 'pk'}:
continue
if isinstance(sym.node, Var):
if sym.node.type is None or isinstance(sym.node.type, AnyType):
typ = sym.node.type
if typ is None or isinstance(typ, AnyType):
# types are not ready, fallback to Any
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
elif isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
elif isinstance(typ, Instance):
field_type = extract_field_setter_type(typ)
if field_type is None:
continue
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
primary_key_type = AnyType(TypeOfAny.special_form)
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
typ = extract_primary_key_type_for_set(ref_to_model.type)
if typ:
primary_key_type = typ
if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
primary_key_type = AnyType(TypeOfAny.implementation_artifact)
# in case it's optional, we need Instance type
referred_to_model = typ.args[1]
is_nullable = helpers.is_optional(referred_to_model)
if is_nullable:
referred_to_model = helpers.make_required(typ.args[1])
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
pk_type = extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
if not pk_type:
# extract set type of AutoField
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
is_nullable=is_nullable)
primary_key_type = pk_type
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type

View File

@@ -4,7 +4,6 @@ from mypy.checker import TypeChecker
from mypy.nodes import Expression, StrExpr, TypeInfo
from mypy.plugin import MethodContext
from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers

View File

@@ -73,8 +73,6 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
# base_model_info = self.api.named_type('builtins.object').type
# helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))