add Model.__init__ typechecking

This commit is contained in:
Maxim Kurnikov
2019-02-08 17:16:03 +03:00
parent dead370244
commit 916df1efb6
16 changed files with 359 additions and 108 deletions

View File

@@ -1,9 +1,13 @@
import typing
from typing import Dict, Optional
from mypy.nodes import MypyFile, TypeInfo, ImportedName, SymbolNode
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
@@ -78,3 +82,116 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
if sym is None:
return None
return sym.node
def parse_bool(expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == 'builtins.True':
return True
if expr.fullname == 'builtins.False':
return False
return None
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str):
return tp.args[tp.type.type_vars.index(typevar_name)]
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
typevar_values: typing.List[Type] = []
for typevar_arg in set_value_type.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
# if there are typevars, extract from
set_value_type = reparametrize_with(set_value_type, typevar_values)
return set_value_type
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
# only primary keys defined in current class for now
for sym in model.names.values():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
if tp.type.metadata.get('django', {}).get('defined_as_primary_key'):
field_type = extract_field_setter_type(tp)
return field_type
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if tp.type.fullname() == FOREIGN_KEY_FULLNAME:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
primary_key_type = extract_primary_key_type(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
return expected_types
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]

View File

@@ -8,6 +8,7 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.helpers import parse_bool
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class
@@ -54,6 +55,40 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret
def redefine_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = helpers.extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name is None:
# We can't check kwargs reliably.
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def set_primary_key_marking(ctx: FunctionContext) -> Type:
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = parse_bool(primary_key_arg)
if is_primary_key:
info = ctx.default_return_type.type
info.metadata.setdefault('django', {})['defined_as_primary_key'] = True
return ctx.default_return_type
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -105,6 +140,13 @@ class DjangoPlugin(Plugin):
if fullname in manager_bases:
return determine_proper_manager_type
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return set_primary_key_marking
if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_model_init
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
if fullname in {'django.apps.registry.Apps.get_model',

View File

@@ -1,13 +1,12 @@
from mypy.plugin import FunctionContext
from mypy.types import Type, Instance
from mypy_django_plugin import helpers
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
if 'base_field' not in ctx.callee_arg_names:
return ctx.default_return_type
base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0]
if not isinstance(base_field_arg_type, Instance):
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return ctx.default_return_type
get_method = base_field_arg_type.type.get_method('__get__')

View File

@@ -2,11 +2,12 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, Optional, Tuple, cast
import dataclasses
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, \
StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance
from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp
from mypy_django_plugin import helpers
@@ -199,16 +200,27 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
return None
def add_dummy_init_method(ctx: ClassDefContext) -> None:
any = AnyType(TypeOfAny.special_form)
var = Var('kwargs', any)
kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2)
add_method(ctx, '__init__', [kw_arg], NoneTyp())
# mark as model class
ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True
def process_model_class(ctx: ClassDefContext) -> None:
initializers = [
InjectAnyAsBaseForNestedMeta,
AddDefaultObjectsManager,
AddIdAttributeIfPrimaryKeyTrueIsNotSet,
SetIdAttrsForRelatedFields,
AddRelatedManagers
AddRelatedManagers,
]
for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run()
add_dummy_init_method(ctx)
# allow unspecified attributes for now
ctx.cls.info.fallback_to_any = True

View File

@@ -1,20 +1,12 @@
import typing
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
from mypy.types import CallableType, Instance, Type
from mypy_django_plugin import helpers
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: