add support for managers as generics

This commit is contained in:
Maxim Kurnikov
2018-12-07 22:11:22 +03:00
parent 94ddb8c864
commit c9ad40d7e3
8 changed files with 237 additions and 187 deletions

View File

@@ -4,11 +4,23 @@ from typing import Dict, Optional
from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models'])

View File

@@ -1,24 +1,51 @@
import os
from typing import Callable, Optional
from typing import Callable, Optional, cast
from mypy.checker import TypeChecker
from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext
from mypy.types import Type
from mypy.types import Type, Instance
from mypy.typevars import fill_typevars
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
manager_subclasses = set()
class TransformModelClassHook(object):
def __call__(self, ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
process_model_class(ctx)
def transform_model_class(ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
process_model_class(ctx)
def add_new_manager_subclass(ctx: ClassDefContext) -> None:
manager_subclasses.add(ctx.cls.fullname)
def determine_proper_manager_type(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
ret = ctx.default_return_type
if not api.tscope.classes:
# not in class
return ret
outer_model_info = api.tscope.classes[0]
if not outer_model_info.has_base(helpers.MODEL_CLASS_FULLNAME):
return ret
if not isinstance(ret, Instance):
return ret
for i, base in enumerate(ret.type.bases):
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}:
ret.type.bases[i] = reparametrize_with(base, [Instance(outer_model_info, [])])
return ret
return ret
class DjangoPlugin(Plugin):
@@ -39,21 +66,27 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}:
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
if fullname in manager_subclasses:
return determine_proper_manager_type
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
return TransformModelClassHook()
return transform_model_class
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings_module=self.django_settings)
if fullname in helpers.MANAGER_CLASSES:
return add_new_manager_subclass
return None

View File

@@ -1,3 +1,5 @@
import dataclasses
from abc import abstractmethod, ABCMeta
from typing import cast, Iterator, Tuple, Optional, Dict
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
@@ -9,6 +11,47 @@ from mypy.types import Instance
from mypy_django_plugin import helpers
@dataclasses.dataclass
class ModelClassInitializer(metaclass=ABCMeta):
api: SemanticAnalyzerPass2
model_classdef: ClassDef
@classmethod
def from_ctx(cls, ctx: ClassDefContext):
return cls(api=cast(SemanticAnalyzerPass2, ctx.api), model_classdef=ctx.cls)
def get_nested_meta_node(self) -> Optional[TypeInfo]:
metaclass_sym = self.model_classdef.info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def is_abstract_model(self) -> bool:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return False
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
is_abstract = self.api.parse_bool(rvalue)
if is_abstract:
# abstract model do not need 'objects' queryset
return True
return False
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
var._fullname = self.model_classdef.info.fullname() + '.' + name
var.is_inferred = True
var.is_initialized_in_class = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var)
@abstractmethod
def run(self) -> None:
raise NotImplementedError()
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
@@ -37,78 +80,76 @@ def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)):
and isinstance(rvalue.callee, MemberExpr)):
if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}:
yield lvalue, rvalue
def get_nested_meta_class(model_type: TypeInfo) -> Optional[TypeInfo]:
metaclass_sym = model_type.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
self.add_new_node_to_model_class(lvalue.name + '_id',
typ=self.api.named_type('__builtins__.int'))
def is_abstract_model(ctx: ClassDefContext) -> bool:
meta_node = get_nested_meta_class(ctx.cls.info)
if meta_node is None:
return False
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
is_abstract = ctx.api.parse_bool(rvalue)
if is_abstract:
# abstract model do not need 'objects' queryset
return True
return False
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
def run(self) -> None:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return None
meta_node.fallback_to_any = True
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls):
property_name = lvalue.name + '_id'
add_new_var_node_to_class(ctx.cls.info, property_name,
typ=api.named_type('__builtins__.int'))
class AddDefaultObjectsManager(ModelClassInitializer):
def run(self) -> None:
if 'objects' in self.model_classdef.info.names:
return None
if self.is_abstract_model():
# abstract models do not need 'objects' queryset
return None
typ = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
if not typ:
return None
self.add_new_node_to_model_class('objects', typ)
def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
if is_abstract_model(ctx):
return None
class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):
def run(self) -> None:
if self.is_abstract_model():
# no need for .id attr
return None
for _, rvalue in iter_call_assignments(ctx.cls):
if ('primary_key' in rvalue.arg_names and
api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
break
else:
add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int'))
for _, rvalue in iter_call_assignments(self.model_classdef):
if ('primary_key' in rvalue.arg_names
and self.api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
break
else:
self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.int'))
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
# search over mro
objects_sym = ctx.cls.info.get('objects')
if objects_sym is not None:
return None
# only direct Meta class
if is_abstract_model(ctx):
# abstract model do not need 'objects' queryset
return None
api = cast(SemanticAnalyzerPass2, ctx.api)
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(ctx.cls.info, [])])
if not typ:
return None
add_new_var_node_to_class(ctx.cls.info, 'objects', typ=typ)
def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
meta_node = get_nested_meta_class(ctx.cls.info)
if meta_node is None:
return None
meta_node.fallback_to_any = True
class AddRelatedManagers(ModelClassInitializer):
def run(self) -> None:
for module_name, module_file in self.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
ref_to_fullname = extract_ref_to_fullname(rvalue,
module_file=module_file,
all_modules=self.api.modules)
if self.model_classdef.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_name = related_name_expr.value
typ = get_related_field_type(rvalue, self.api, defn.info)
if typ is None:
return None
self.add_new_node_to_model_class(related_name, typ)
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
@@ -119,8 +160,8 @@ def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
related_model_typ: TypeInfo) -> Optional[Instance]:
if rvalue.callee.name == 'ForeignKey':
return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}:
return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME,
args=[Instance(related_model_typ, [])])
else:
return Instance(related_model_typ, [])
@@ -129,7 +170,9 @@ def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
module = module_file.names[expr.callee.expr.name]
if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey', 'OneToOneField'}:
if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey',
'OneToOneField',
'ManyToManyField'}:
return True
return False
@@ -150,31 +193,16 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
return None
def add_related_managers(ctx: ClassDefContext):
api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name, module_file in ctx.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file,
all_modules=api.modules)
if ctx.cls.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_name = related_name_expr.value
typ = get_related_field_type(rvalue, api, defn.info)
if typ is None:
return None
add_new_var_node_to_class(ctx.cls.info, related_name, typ)
def process_model_class(ctx: ClassDefContext) -> None:
add_related_managers(ctx)
inject_any_as_base_for_nested_class_meta(ctx)
set_fieldname_attrs_for_related_fields(ctx)
add_int_id_attribute_if_primary_key_true_is_not_present(ctx)
set_objects_queryset_to_model_class(ctx)
initializers = [
InjectAnyAsBaseForNestedMeta,
AddDefaultObjectsManager,
AddIdAttributeIfPrimaryKeyTrueIsNotSet,
SetIdAttrsForRelatedFields,
AddRelatedManagers
]
for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run()
# allow unspecified attributes for now
ctx.cls.info.fallback_to_any = True