add proper __init__, create() support

This commit is contained in:
Maxim Kurnikov
2019-07-16 16:49:49 +03:00
parent b11a9a85f9
commit 2cb1f257eb
26 changed files with 306 additions and 463 deletions
@@ -1,12 +1,12 @@
from typing import Optional, Tuple, cast
from mypy.checker import TypeChecker
from mypy.nodes import Expression, ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, TupleType, Type as MypyType, UnionType
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
from mypy_django_plugin_newsemanal.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, helpers, metadata
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]:
@@ -1,35 +1,69 @@
from typing import cast
from typing import List, Tuple, Type, Union
from mypy.checker import TypeChecker
from mypy.nodes import Argument, Var, ARG_NAMED
from mypy.plugin import FunctionContext
from mypy.types import Type as MypyType, Instance
from django.db.models.base import Model
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Instance, Type as MypyType
from mypy_django_plugin_newsemanal.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers
from mypy_django_plugin_newsemanal.django.context import DjangoContext
def get_actual_types(ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
actual_types = []
# positionals
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
if actual_name is None:
if ctx.callee_arg_names[0] == 'kwargs':
# unpacked dict as kwargs is not supported
continue
actual_name = expected_keys[pos]
actual_types.append((actual_name, actual_type))
# kwargs
if len(ctx.callee_arg_names) > 1:
for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
actual_types.append((actual_name, actual_type))
return actual_types
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType:
expected_types = django_context.get_expected_types(ctx.api, model_cls, method)
expected_keys = [key for key in expected_types.keys() if key != 'pk']
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model_cls.__name__),
ctx.context)
continue
ctx.api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
model_info = ctx.default_return_type.type
model_cls = django_context.get_model_class_by_fullname(model_info.fullname())
return typecheck_model_method(ctx, django_context, model_cls, '__init__')
# expected_types = {}
# for field in model_cls._meta.get_fields():
# field_fullname = helpers.get_class_fullname(field.__class__)
# field_info = api.lookup_typeinfo(field_fullname)
# field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
# is_nullable=False)
# field_kwarg = Argument(variable=Var(field.attname, field_set_type),
# type_annotation=field_set_type,
# initializer=None,
# kind=ARG_NAMED)
# expected_types[field.attname] = field_set_type
# for field_name, field in model_cls._meta.fields_map.items():
# print()
# print()
return ctx.default_return_type
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
isinstance(ctx.default_return_type, Instance)
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
return typecheck_model_method(ctx, django_context, model_cls, 'create')
@@ -1,18 +1,15 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Optional, Type, cast
from typing import cast
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import ARG_NAMED_OPT, Argument, ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.plugins import common
from mypy.types import AnyType, Instance, NoneType, Type as MypyType, UnionType
from mypy.types import Instance
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field
from mypy_django_plugin_newsemanal.context import DjangoContext
from django.db.models.fields import Field
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers
from mypy_django_plugin_newsemanal.transformers import fields
from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types
@@ -52,101 +49,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
var.is_initialized_in_class = True
var.is_inferred = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
# assert self.model_classdef.info == self.api.type
# self.api.add_symbol_table_node(name, SymbolTableNode(MDEF, var, plugin_generated=True))
def convert_any_to_type(self, typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(self.convert_any_to_type(item, referred_to_type))
return UnionType.make_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def get_field_set_type(self, field: Field, method: str) -> MypyType:
target_field = field
if isinstance(field, ForeignKey):
target_field = field.target_field
field_fullname = helpers.get_class_fullname(target_field.__class__)
field_info = self.lookup_typeinfo_or_incomplete_defn_error(field_fullname)
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=self.get_field_nullability(field, method))
if isinstance(target_field, ArrayField):
argument_field_type = self.get_field_set_type(target_field.base_field, method)
field_set_type = self.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
nullable = field.null
if not nullable and isinstance(field, CharField) and field.blank:
return True
if method == '__init__':
if field.primary_key or isinstance(field, ForeignKey):
return True
return nullable
def get_field_kind(self, field: Field, method: str):
if method == '__init__':
# all arguments are optional in __init__
return ARG_NAMED_OPT
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
if field.primary_key:
return field
raise ValueError('No primary key defined')
def make_field_kwarg(self, name: str, field: Field, method: str) -> Argument:
field_set_type = self.get_field_set_type(field, method)
kind = self.get_field_kind(field, method)
field_kwarg = Argument(variable=Var(name, field_set_type),
type_annotation=field_set_type,
initializer=None,
kind=kind)
return field_kwarg
def get_field_kwargs(self, model_cls: Type[Model], method: str):
field_kwargs = []
if method == '__init__':
# add primary key `pk`
primary_key_field = self.get_primary_key_field(model_cls)
field_kwarg = self.make_field_kwarg('pk', primary_key_field, method)
field_kwargs.append(field_kwarg)
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
field_kwarg = self.make_field_kwarg(field.attname, field, method)
field_kwargs.append(field_kwarg)
if isinstance(field, ForeignKey):
attname = field.name
related_model_fullname = helpers.get_class_fullname(field.related_model)
model_info = self.lookup_typeinfo_or_incomplete_defn_error(related_model_fullname)
is_nullable = self.get_field_nullability(field, method)
field_set_type = Instance(model_info, [])
if is_nullable:
field_set_type = helpers.make_optional(field_set_type)
kind = self.get_field_kind(field, method)
field_kwarg = Argument(variable=Var(attname, field_set_type),
type_annotation=field_set_type,
initializer=None,
kind=kind)
field_kwargs.append(field_kwarg)
return field_kwargs
@abstractmethod
def run(self) -> None:
@@ -198,9 +100,9 @@ class AddRelatedModelsId(ModelClassInitializer):
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
rel_primary_key_field = self.get_primary_key_field(field.related_model)
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
field_info = self.lookup_field_typeinfo_or_incomplete_defn_error(rel_primary_key_field)
is_nullable = self.get_field_nullability(field, None)
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
self.add_new_node_to_model_class(field.attname,
Instance(field_info, [set_type, get_type]))
@@ -228,16 +130,6 @@ class AddManagers(ModelClassInitializer):
self.add_new_node_to_model_class('_default_manager', default_manager)
class AddInitMethod(ModelClassInitializer):
def run(self):
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.info.fullname())
if model_cls is None:
return
field_kwargs = self.get_field_kwargs(model_cls, '__init__')
common.add_method(self.ctx, '__init__', field_kwargs, NoneType())
def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None:
initializers = [
@@ -245,11 +137,10 @@ def process_model_class(ctx: ClassDefContext,
AddDefaultPrimaryKey,
AddRelatedModelsId,
AddManagers,
AddInitMethod
]
for initializer_cls in initializers:
try:
initializer_cls.from_ctx(ctx, django_context).run()
except helpers.IncompleteDefnException:
if not ctx.api.final_iteration:
ctx.api.defer()
ctx.api.defer()
@@ -2,7 +2,7 @@ from mypy.nodes import TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import Type as MypyType, TypeType, Instance
from mypy_django_plugin_newsemanal.context import DjangoContext
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers