fix couple edge cases with __init__

This commit is contained in:
Maxim Kurnikov
2019-02-10 04:32:27 +03:00
parent 5f6f597266
commit 6b7507206a
8 changed files with 172 additions and 71 deletions

View File

@@ -1,9 +1,10 @@
import typing
from typing import Dict, Optional
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
CallExpr
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -101,10 +102,23 @@ def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str):
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
if typevar_name in {'_T', '_T_co'}:
if '_T' in tp.type.type_vars:
return tp.args[tp.type.type_vars.index('_T')]
if '_T_co' in tp.type.type_vars:
return tp.args[tp.type.type_vars.index('_T_co')]
return tp.args[tp.type.type_vars.index(typevar_name)]
def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
typevar_values: typing.List[Type] = []
for typevar_arg in type_to_fill.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
return reparametrize_with(type_to_fill, typevar_values)
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
@@ -112,13 +126,15 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
typevar_values: typing.List[Type] = []
for typevar_arg in set_value_type.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
# if there are typevars, extract from
set_value_type = reparametrize_with(set_value_type, typevar_values)
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
@@ -131,31 +147,20 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
# only primary keys defined in current class for now
for sym in model.names.values():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
if tp.type.metadata.get('django', {}).get('defined_as_primary_key'):
field_type = extract_field_setter_type(tp)
return field_type
for stmt in model.defn.defs.body:
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
name_expr = stmt.lvalues[0]
if isinstance(name_expr, NameExpr):
name = name_expr.name
if 'primary_key' in stmt.rvalue.arg_names:
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
if is_primary_key:
return extract_field_setter_type(model.names[name].type)
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if tp.type.fullname() == FOREIGN_KEY_FULLNAME:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
primary_key_type = extract_primary_key_type(model)
if not primary_key_type:
@@ -164,6 +169,21 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
return expected_types

View File

@@ -1,6 +1,6 @@
import os
from configparser import ConfigParser
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Set, cast
from dataclasses import dataclass
from mypy.checker import TypeChecker
@@ -10,7 +10,6 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.helpers import parse_bool
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class
@@ -57,6 +56,16 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
@@ -64,9 +73,33 @@ def redefine_model_init(ctx: FunctionContext) -> Type:
model: TypeInfo = ctx.default_return_type.type
expected_types = helpers.extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# We can't check kwargs reliably.
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
@@ -81,16 +114,6 @@ def redefine_model_init(ctx: FunctionContext) -> Type:
return ctx.default_return_type
def set_primary_key_marking(ctx: FunctionContext) -> Type:
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = parse_bool(primary_key_arg)
if is_primary_key:
info = ctx.default_return_type.type
info.metadata.setdefault('django', {})['defined_as_primary_key'] = True
return ctx.default_return_type
@dataclass
class Config:
django_settings_module: Optional[str] = None
@@ -171,8 +194,6 @@ class DjangoPlugin(Plugin):
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return set_primary_key_marking
if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_model_init

View File

@@ -3,7 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple, cast
import dataclasses
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2, ARG_STAR
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2
@@ -202,9 +202,13 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
def add_dummy_init_method(ctx: ClassDefContext) -> None:
any = AnyType(TypeOfAny.special_form)
var = Var('kwargs', any)
kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2)
add_method(ctx, '__init__', [kw_arg], NoneTyp())
pos_arg = Argument(variable=Var('args', any),
type_annotation=any, initializer=None, kind=ARG_STAR)
kw_arg = Argument(variable=Var('kwargs', any),
type_annotation=any, initializer=None, kind=ARG_STAR2)
add_method(ctx, '__init__', [pos_arg, kw_arg], NoneTyp())
# mark as model class
ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True