preliminary support for strict_optional

This commit is contained in:
Maxim Kurnikov
2019-02-17 18:07:53 +03:00
parent 6763217a80
commit e9f9202ed1
23 changed files with 614 additions and 343 deletions

View File

@@ -0,0 +1,166 @@
from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_explicit_set_type_of_model_primary_key, get_fields_metadata
from mypy_django_plugin.transformers.fields import get_private_descriptor_type
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = extract_expected_types(ctx, model)
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
api = cast(TypeChecker, ctx.api)
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
model_generic_arg = ctx.type.args[0]
else:
model_generic_arg = ctx.default_return_type
if isinstance(model_generic_arg, AnyType):
return ctx.default_return_type
model: TypeInfo = model_generic_arg.type
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
expected_types = extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata:
return field_metadata['choices']
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
api = cast(TypeChecker, ctx.api)
expected_types: Dict[str, Type] = {}
primary_key_type = extract_explicit_set_type_of_model_primary_key(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
# extract all fields for all models in MRO
for name, sym in base.names.items():
# do not redefine special attrs
if name in {'_meta', 'pk'}:
continue
if isinstance(sym.node, Var):
typ = sym.node.type
if typ is None or isinstance(typ, AnyType):
# types are not ready, fallback to Any
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
elif isinstance(typ, Instance):
field_type = extract_field_setter_type(typ)
if field_type is None:
continue
if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
primary_key_type = AnyType(TypeOfAny.implementation_artifact)
# in case it's optional, we need Instance type
referred_to_model = typ.args[1]
is_nullable = helpers.is_optional(referred_to_model)
if is_nullable:
referred_to_model = helpers.make_required(typ.args[1])
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
pk_type = extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
if not pk_type:
# extract set type of AutoField
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
is_nullable=is_nullable)
primary_key_type = pk_type
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
return expected_types