nested class Meta support

This commit is contained in:
Maxim Kurnikov
2018-11-30 14:00:11 +03:00
parent 00f72f97d7
commit 60b1c48ade
13 changed files with 433 additions and 87 deletions

View File

@@ -0,0 +1,41 @@
from typing import Iterator, List, cast
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Type, Instance
from mypy_django_plugin.plugins.related_fields import add_new_var_node_to_class
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
if 'base_field' not in ctx.arg_names:
return ctx.default_return_type
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
def get_assignments(klass: ClassDef) -> List[AssignmentStmt]:
stmts = []
for stmt in klass.defs.body:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
stmts.append(stmt)
return stmts
def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
for stmt in get_assignments(ctx.cls):
if (isinstance(stmt.rvalue, CallExpr)
and 'primary_key' in stmt.rvalue.arg_names
and api.parse_bool(stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')])):
break
else:
add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int'))

View File

@@ -0,0 +1,12 @@
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
if 'Meta' not in ctx.cls.info.names:
return None
sym = ctx.cls.info.names['Meta']
if not isinstance(sym.node, TypeInfo):
return None
sym.node.fallback_to_any = True

View File

@@ -1,11 +0,0 @@
from mypy.plugin import FunctionContext
from mypy.types import Type
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
if 'base_field' not in ctx.arg_names:
return ctx.default_return_type
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])

View File

@@ -3,15 +3,16 @@ from typing import Optional, cast
from django.conf import Settings
from mypy.checker import TypeChecker
from mypy.nodes import SymbolTable, MDEF, AssignmentStmt
from mypy.nodes import MDEF, AssignmentStmt, MypyFile, StrExpr, TypeInfo, NameExpr, Var, SymbolTableNode
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import get_models_file
def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.arg_names.index('related_name')].value
return ctx.args[ctx.arg_names.index('related_name')][0].value
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
@@ -31,8 +32,15 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
# to= defined as string is not supported
return None
to_arg_expr = ctx.args[ctx.arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
model_info = helpers.get_model_type_from_string(to_arg_expr,
all_modules=cast(TypeChecker, ctx.api).modules)
if model_info is None:
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
for base in referred_to_type.type.bases:
@@ -47,59 +55,38 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
return referred_to_type
class ForeignKeyHook(object):
def __init__(self, settings: Settings):
self.settings = settings
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=queryset_type)
referred_to_type.type.names[related_name] = sym
return reparametrize_with(ctx.default_return_type, [referred_to_type])
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
var._fullname = class_type.fullname() + '.' + name
var.is_inferred = True
var.is_initialized_in_class = True
class_type.names[name] = SymbolTableNode(MDEF, var)
class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings]):
self.settings = settings
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=Instance(outer_class_info, []))
referred_to_type.type.names[related_name] = sym
return reparametrize_with(ctx.default_return_type, [referred_to_type])
def extract_to_parameter_as_get_ret_type(ctx: FunctionContext) -> Type:
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
# couldn't extract to= value
return fill_typevars_with_any(ctx.default_return_type)
return reparametrize_with(ctx.default_return_type, [referred_to_type])
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api
new_symtable_nodes = SymbolTable()
for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
for stmt in ctx.cls.defs.body:
if not isinstance(stmt, AssignmentStmt):
continue
if not hasattr(stmt.rvalue, 'callee'):
continue
if len(stmt.lvalues) > 1:
# multiple lvalues not supported for now
continue
expr = stmt.lvalues[0]
if not isinstance(expr, NameExpr):
continue
name = expr.name
rvalue_callee = stmt.rvalue.callee
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
@@ -108,7 +95,4 @@ def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
new_node = helpers.create_new_symtable_node(name,
kind=MDEF,
instance=api.named_type('__builtins__.int'))
new_symtable_nodes[name] = new_node
for name, node in new_symtable_nodes.items():
ctx.cls.info.names[name] = node
ctx.cls.info.names[name] = new_node