add support for forms, values, values_list

This commit is contained in:
Maxim Kurnikov
2019-07-17 21:12:17 +03:00
parent 3c3122a93f
commit d53121baae
15 changed files with 594 additions and 560 deletions

View File

@@ -1,10 +1,11 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.related import ForeignKey, RelatedField
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.types import Instance, Type as MypyType
@@ -13,6 +14,8 @@ from pytest_mypy.utils import temp_environ
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from mypy_django_plugin_newsemanal.lib import helpers
if TYPE_CHECKING:
@@ -53,6 +56,9 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
class DjangoFieldsContext:
def __init__(self, django_context: 'DjangoContext') -> None:
self.django_context = django_context
def get_attname(self, field: Field) -> str:
attname = field.attname
return attname
@@ -81,11 +87,43 @@ class DjangoFieldsContext:
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type
def get_field_get_type(self, api: TypeChecker, field: Field, method: str) -> MypyType:
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model)
return self.get_field_get_type(api, primary_key_field, method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model)
return Instance(model_info, [])
else:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable)
class DjangoLookupsContext:
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Any:
query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts:
raise FieldError('Lookups not supported yet')
currently_observed_model = model_cls
current_field = None
for field_name in field_parts:
current_field = currently_observed_model._meta.get_field(field_name)
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
return current_field
class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig()
self.fields_context = DjangoFieldsContext()
self.fields_context = DjangoFieldsContext(self)
self.lookups_context = DjangoLookupsContext()
self.django_settings_module = None
if plugin_toml_config:

View File

@@ -1,9 +1,17 @@
from typing import Dict, List, Optional, Set, Union
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Union, Any
from mypy import checker
from mypy.checker import TypeChecker
from mypy.nodes import Expression, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, SymbolTableNode
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, Type as MypyType, TypeOfAny, UnionType
from mypy.mro import calculate_mro
from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, \
TypeInfo, Var
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
class IncompleteDefnException(Exception):
@@ -120,6 +128,53 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
return None
def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[Instance],
fields: 'OrderedDict[str, MypyType]') -> TypeInfo:
current_module = api.scope.stack[0]
new_class_unique_name = checker.gen_unique_name(name, current_module.names)
# make new class expression
classdef = ClassDef(new_class_unique_name, Block([]))
classdef.fullname = current_module.fullname() + '.' + new_class_unique_name
# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, current_module.fullname())
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
def add_field_to_new_typeinfo(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
var.info = new_typeinfo
var.is_initialized_in_class = is_initialized_in_class
var.is_property = is_property
var._fullname = new_typeinfo.fullname() + '.' + var.name()
new_typeinfo.names[var.name()] = SymbolTableNode(MDEF, var)
# add fields
var_items = [Var(item, typ) for item, typ in fields.items()]
for var_item in var_items:
add_field_to_new_typeinfo(var_item, is_property=True)
classdef.info = new_typeinfo
current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
return new_typeinfo
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
namedtuple_info = add_new_class_for_current_module(api, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
# fallback for tuples is any builtins.tuple instance
fallback = api.named_generic_type('builtins.tuple',
[AnyType(TypeOfAny.special_form)])
return TupleType(fields, fallback=fallback)
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
@@ -140,3 +195,10 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
return referred_to_type
return typ
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
required_keys: Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type

View File

@@ -1,27 +0,0 @@
from typing import Any, Dict, List
from mypy.nodes import TypeInfo
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> List[str]:
return get_django_metadata(base_model).setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('fields', {})
def get_lookups_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('lookups', {})
def get_related_managers_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('related_managers', {})
def get_managers_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('managers', {})

View File

@@ -1,17 +1,17 @@
import os
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple
import toml
from django.db.models.fields.related import RelatedField
from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, MethodContext, AttributeContext
from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Type as MypyType
from django.db.models.fields.related import RelatedField
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, metadata
from mypy_django_plugin_newsemanal.transformers import fields, settings, querysets, init_create
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
from mypy_django_plugin_newsemanal.transformers import fields, forms, init_create, querysets, settings
from mypy_django_plugin_newsemanal.transformers.models import process_model_class
@@ -20,7 +20,7 @@ def transform_model_class(ctx: ClassDefContext,
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MODEL_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
metadata.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
else:
if not ctx.api.final_iteration:
ctx.api.defer()
@@ -29,10 +29,18 @@ def transform_model_class(ctx: ClassDefContext,
process_model_class(ctx, django_context)
def transform_form_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1
forms.make_meta_nested_class_inherit_from_any(ctx)
def add_new_manager_base(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
class NewSemanalDjangoPlugin(Plugin):
@@ -50,7 +58,7 @@ class NewSemanalDjangoPlugin(Plugin):
def _get_current_queryset_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (metadata.get_django_metadata(model_sym.node)
return (helpers.get_django_metadata(model_sym.node)
.setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1}))
else:
return {}
@@ -58,7 +66,7 @@ class NewSemanalDjangoPlugin(Plugin):
def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (metadata.get_django_metadata(model_sym.node)
return (helpers.get_django_metadata(model_sym.node)
.setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1}))
else:
return {}
@@ -66,8 +74,18 @@ class NewSemanalDjangoPlugin(Plugin):
def _get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return metadata.get_django_metadata(model_sym.node).setdefault('model_bases',
{fullnames.MODEL_CLASS_FULLNAME: 1})
return helpers.get_django_metadata(model_sym.node).setdefault('model_bases',
{fullnames.MODEL_CLASS_FULLNAME: 1})
else:
return {}
def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1,
fullnames.FORM_CLASS_FULLNAME: 1,
fullnames.MODELFORM_CLASS_FULLNAME: 1}))
else:
return {}
@@ -85,15 +103,20 @@ class NewSemanalDjangoPlugin(Plugin):
if file.fullname() == 'django.conf' and self.django_context.django_settings_module:
return [self._new_dependency(self.django_context.django_settings_module)]
# for values / values_list
if file.fullname() == 'django.db.models':
return [self._new_dependency('mypy_extensions'), self._new_dependency('typing')]
# for `get_user_model()`
if file.fullname() == 'django.contrib.auth':
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
try:
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
except LookupError:
# get_user_model() model app is not installed
return []
return [self._new_dependency(auth_user_module)]
if self.django_context.settings:
if file.fullname() == 'django.contrib.auth':
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
try:
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
except LookupError:
# get_user_model() model app is not installed
return []
return [self._new_dependency(auth_user_module)]
# ensure that all mentioned to='someapp.SomeModel' are loaded with corresponding related Fields
defined_model_classes = self.django_context.model_modules.get(file.fullname())
@@ -132,9 +155,29 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
manager_classes = self._get_current_manager_bases()
) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition('.')
if method_name == 'get_form_class':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form_class
if method_name == 'get_form':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form
if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
@@ -146,6 +189,9 @@ class NewSemanalDjangoPlugin(Plugin):
if fullname in self._get_current_manager_bases():
return add_new_manager_base
if fullname in self._get_current_form_bases():
return transform_form_class
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], MypyType]]:
class_name, _, attr_name = fullname.rpartition('.')
@@ -153,12 +199,6 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(settings.get_type_of_settings_attribute,
django_context=self.django_context)
# def get_type_analyze_hook(self, fullname: str
# ) -> Optional[Callable[[AnalyzeTypeContext], MypyType]]:
# queryset_bases = self._get_current_queryset_bases()
# if fullname in queryset_bases:
# return partial(querysets.set_first_generic_param_as_default_for_second, fullname=fullname)
def plugin(version):
return NewSemanalDjangoPlugin

View File

@@ -0,0 +1,50 @@
from typing import Optional
from mypy.plugin import ClassDefContext, MethodContext
from mypy.types import CallableType, Instance, NoneTyp, Type as MypyType, TypeType
from mypy_django_plugin.lib import helpers
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info)
if meta_node is None:
if not ctx.api.final_iteration:
ctx.api.defer()
else:
meta_node.fallback_to_any = True
def get_specified_form_class(object_type: Instance) -> Optional[TypeType]:
form_class_sym = object_type.type.get('form_class')
if form_class_sym and isinstance(form_class_sym.type, CallableType):
return TypeType(form_class_sym.type.ret_type)
return None
def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType:
object_type = ctx.type
assert isinstance(object_type, Instance)
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp):
form_class_type = get_specified_form_class(object_type)
if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance):
return form_class_type.item
if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance):
return form_class_type.ret_type
return ctx.default_return_type
def extract_proper_type_for_get_form_class(ctx: MethodContext) -> MypyType:
object_type = ctx.type
assert isinstance(object_type, Instance)
form_class_type = get_specified_form_class(object_type)
if form_class_type is None:
return ctx.default_return_type
return form_class_type

View File

@@ -3,15 +3,14 @@ from abc import ABCMeta, abstractmethod
from typing import cast
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.types import Instance
from django.db.models.fields import Field
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel, ManyToManyRel
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers, fullnames
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
from mypy_django_plugin_newsemanal.transformers import fields
from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types

View File

@@ -1,6 +1,14 @@
from mypy.plugin import AnalyzeTypeContext, FunctionContext
from collections import OrderedDict
from typing import Optional, Tuple, Type
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from mypy.nodes import NameExpr
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
@@ -37,3 +45,122 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
return ret
return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])])
def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
lookup: str, method: str) -> Optional[Tuple[str, MypyType]]:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return None
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method)
return lookup_field.attname, field_get_type
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
flat: bool, named: bool) -> MypyType:
field_lookups = [expr.value for expr in ctx.args[0]]
if len(field_lookups) == 0:
if flat:
primary_key_field = django_context.get_primary_key_field(model_cls)
_, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls,
primary_key_field.attname, 'values_list')
return field_get_type
elif named:
column_types = OrderedDict()
for field in django_context.get_model_fields(model_cls):
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
column_types[field.attname] = field_get_type
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else:
# flat=False, named=False, all fields
field_lookups = []
for field in model_cls._meta.get_fields():
field_lookups.append(field.attname)
if len(field_lookups) > 1 and flat:
ctx.api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
return AnyType(TypeOfAny.from_error)
column_types = OrderedDict()
for field_lookup in field_lookups:
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list')
if result is None:
return AnyType(TypeOfAny.from_error)
field_name, field_get_type = result
column_types[field_name] = field_get_type
if flat:
assert len(column_types) == 1
row_type = next(iter(column_types.values()))
elif named:
row_type = helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else:
row_type = helpers.make_tuple(ctx.api, list(column_types.values()))
return row_type
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance)
model_type = ctx.type.args[0]
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return ctx.default_return_type
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
if flat_expr is not None and isinstance(flat_expr, NameExpr):
flat = helpers.parse_bool(flat_expr)
else:
flat = False
named_expr = helpers.get_call_argument_by_name(ctx, 'named')
if named_expr is not None and isinstance(named_expr, NameExpr):
named = helpers.parse_bool(named_expr)
else:
named = False
if flat and named:
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
row_type = get_values_list_row_type(ctx, django_context, model_cls,
flat=flat, named=named)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance)
model_type = ctx.type.args[0]
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return ctx.default_return_type
field_lookups = [expr.value for expr in ctx.args[0]]
if len(field_lookups) == 0:
for field in model_cls._meta.get_fields():
field_lookups.append(field.attname)
column_types = OrderedDict()
for field_lookup in field_lookups:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values')
field_name = lookup_field.attname
if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name:
field_name = lookup_field.name
column_types[field_name] = field_get_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])