From ade48b6546773eab3986d941f3ae3bf677f15f49 Mon Sep 17 00:00:00 2001 From: Maksim Kurnikov Date: Thu, 12 Dec 2019 05:35:56 +0300 Subject: [PATCH] Add support for BaseManager.from_queryset() (#251) * add support for BaseManager.from_queryset() * cleanups * lint fixes --- mypy_django_plugin/lib/helpers.py | 65 +++++-- mypy_django_plugin/main.py | 14 +- mypy_django_plugin/transformers/managers.py | 71 ++++++++ mypy_django_plugin/transformers/models.py | 162 ++++++++++-------- .../managers/querysets/test_from_queryset.yml | 52 ++++++ 5 files changed, 273 insertions(+), 91 deletions(-) create mode 100644 mypy_django_plugin/transformers/managers.py create mode 100644 test-data/typecheck/managers/querysets/test_from_queryset.yml diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index af99bf6..0a54330 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,6 +1,6 @@ from collections import OrderedDict from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, cast, + TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, ) from django.db.models.fields import Field @@ -10,13 +10,15 @@ from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro from mypy.nodes import ( - GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, - SymbolTableNode, TypeInfo, Var, + GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, + SymbolTable, SymbolTableNode, TypeInfo, Var, ) from mypy.plugin import ( - AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, + AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, ) -from mypy.types import AnyType, Instance, NoneTyp, TupleType +from mypy.plugins.common import add_method +from mypy.semanal import SemanticAnalyzer +from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType from mypy.types import Type as MypyType from mypy.types import TypedDictType, TypeOfAny, UnionType @@ -55,7 +57,7 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) return sym.node -def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]: +def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]: node = lookup_fully_qualified_generic(fullname, api.modules) if not isinstance(node, TypeInfo): return None @@ -173,8 +175,11 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo] return None -def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], - fields: 'OrderedDict[str, MypyType]') -> TypeInfo: +def add_new_class_for_module(module: MypyFile, + name: str, + bases: List[Instance], + fields: Optional[Dict[str, MypyType]] = None + ) -> TypeInfo: new_class_unique_name = checker.gen_unique_name(name, module.names) # make new class expression @@ -188,11 +193,12 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], new_typeinfo.calculate_metaclass_type() # add fields - for field_name, field_type in fields.items(): - var = Var(field_name, type=field_type) - var.info = new_typeinfo - var._fullname = new_typeinfo.fullname + '.' + field_name - new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) + if fields: + for field_name, field_type in fields.items(): + var = Var(field_name, type=field_type) + var.info = new_typeinfo + var._fullname = new_typeinfo.fullname + '.' + field_name + new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) classdef.info = new_typeinfo module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) @@ -269,10 +275,16 @@ def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionCon return None +def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer: + if not isinstance(ctx.api, SemanticAnalyzer): + raise ValueError('Not a SemanticAnalyzer') + return ctx.api + + def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: if not isinstance(ctx.api, TypeChecker): raise ValueError('Not a TypeChecker') - return cast(TypeChecker, ctx.api) + return ctx.api def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: @@ -298,3 +310,28 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> No var.is_inferred = True info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) + + +def _prepare_new_method_arguments(node: FuncDef) -> Tuple[List[Argument], MypyType]: + arguments = [] + for argument in node.arguments[1:]: + if argument.type_annotation is None: + argument.type_annotation = AnyType(TypeOfAny.unannotated) + arguments.append(argument) + + if isinstance(node.type, CallableType): + return_type = node.type.ret_type + else: + return_type = AnyType(TypeOfAny.unannotated) + + return arguments, return_type + + +def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance, + new_method_name: str, method_node: FuncDef) -> None: + arguments, return_type = _prepare_new_method_arguments(method_node) + add_method(ctx, + new_method_name, + args=arguments, + return_type=return_type, + self_type=self_type) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 3ea89be..30ac0e0 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -7,7 +7,7 @@ from mypy.errors import Errors from mypy.nodes import MypyFile, TypeInfo from mypy.options import Options from mypy.plugin import ( - AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, + AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin, ) from mypy.types import Type as MypyType @@ -17,6 +17,9 @@ from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.transformers import ( fields, forms, init_create, meta, querysets, request, settings, ) +from mypy_django_plugin.transformers.managers import ( + create_new_manager_class_from_from_queryset_method, +) from mypy_django_plugin.transformers.models import process_model_class @@ -242,6 +245,15 @@ class NewSemanalDjangoPlugin(Plugin): return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) return None + def get_dynamic_class_hook(self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if fullname.endswith('from_queryset'): + class_name, _, _ = fullname.rpartition('.') + info = self._get_typeinfo_or_none(class_name) + if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): + return create_new_manager_class_from_from_queryset_method + return None + def plugin(version): return NewSemanalDjangoPlugin diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py new file mode 100644 index 0000000..524b7d1 --- /dev/null +++ b/mypy_django_plugin/transformers/managers.py @@ -0,0 +1,71 @@ +from mypy.nodes import ( + GDEF, FuncDef, MemberExpr, NameExpr, StrExpr, SymbolTableNode, TypeInfo, +) +from mypy.plugin import ClassDefContext, DynamicClassDefContext +from mypy.types import AnyType, Instance, TypeOfAny + +from mypy_django_plugin.lib import helpers + + +def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None: + semanal_api = helpers.get_semanal_api(ctx) + + assert isinstance(ctx.call.callee, MemberExpr) + assert isinstance(ctx.call.callee.expr, NameExpr) + base_manager_info = ctx.call.callee.expr.node + if base_manager_info is None: + if not semanal_api.final_iteration: + semanal_api.defer() + return + + assert isinstance(base_manager_info, TypeInfo) + new_manager_info = semanal_api.basic_new_typeinfo(ctx.name, + basetype_or_fallback=Instance(base_manager_info, + [AnyType(TypeOfAny.unannotated)])) + new_manager_info.line = ctx.call.line + new_manager_info.defn.line = ctx.call.line + new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() + + current_module = semanal_api.cur_mod_node + current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, + plugin_generated=True) + passed_queryset = ctx.call.args[0] + assert isinstance(passed_queryset, NameExpr) + + derived_queryset_fullname = passed_queryset.fullname + assert derived_queryset_fullname is not None + + sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) + assert sym is not None + if sym.node is None: + if not semanal_api.final_iteration: + semanal_api.defer() + else: + # inherit from Any to prevent false-positives, if queryset class cannot be resolved + new_manager_info.fallback_to_any = True + return + + derived_queryset_info = sym.node + assert isinstance(derived_queryset_info, TypeInfo) + + if len(ctx.call.args) > 1: + expr = ctx.call.args[1] + assert isinstance(expr, StrExpr) + custom_manager_generated_name = expr.value + else: + custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name + + custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name]) + if 'from_queryset_managers' not in base_manager_info.metadata: + base_manager_info.metadata['from_queryset_managers'] = {} + base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname + + class_def_context = ClassDefContext(cls=new_manager_info.defn, + reason=ctx.call, api=semanal_api) + self_type = Instance(new_manager_info, []) + for name, sym in derived_queryset_info.names.items(): + if isinstance(sym.node, FuncDef): + helpers.copy_method_to_another_class(class_def_context, + self_type, + new_method_name=name, + method_node=sym.node) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 37d204d..fc7ed39 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,5 +1,4 @@ -from collections import OrderedDict -from typing import List, Tuple, Type +from typing import Dict, Optional, Type, cast from django.db.models.base import Model from django.db.models.fields import DateField, DateTimeField @@ -8,10 +7,10 @@ from django.db.models.fields.reverse_related import ( ManyToManyRel, ManyToOneRel, OneToOneRel, ) from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var -from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface +from mypy.plugin import ClassDefContext from mypy.plugins import common -from mypy.plugins.common import add_method -from mypy.types import AnyType, CallableType, Instance +from mypy.semanal import SemanticAnalyzer +from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny @@ -22,19 +21,22 @@ from mypy_django_plugin.transformers.fields import get_field_descriptor_types class ModelClassInitializer: - api: SemanticAnalyzerPluginInterface + api: SemanticAnalyzer def __init__(self, ctx: ClassDefContext, django_context: DjangoContext): - self.api = ctx.api + self.api = cast(SemanticAnalyzer, ctx.api) self.model_classdef = ctx.cls self.django_context = django_context self.ctx = ctx + def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: + return helpers.lookup_fully_qualified_typeinfo(self.api, fullname) + def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo: - sym = self.api.lookup_fully_qualified_or_none(fullname) - if sym is None or not isinstance(sym.node, TypeInfo): + info = self.lookup_typeinfo(fullname) + if info is None: raise helpers.IncompleteDefnException(f'No {fullname!r} found') - return sym.node + return info def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo: fullname = helpers.get_class_fullname(klass) @@ -135,11 +137,64 @@ class AddRelatedModelsId(ModelClassInitializer): class AddManagers(ModelClassInitializer): - def _is_manager_any(self, typ: Instance) -> bool: - return typ.type.fullname == fullnames.MANAGER_CLASS_FULLNAME and type(typ.args[0]) == AnyType + def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool: + for base in helpers.iter_bases(info): + if self.is_any_parametrized_manager(base): + return True + return False + + def is_any_parametrized_manager(self, typ: Instance) -> bool: + return typ.type.fullname == fullnames.MANAGER_CLASS_FULLNAME and isinstance(typ.args[0], AnyType) + + def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]: + base_manager_info = self.lookup_typeinfo(base_manager_fullname) + if (base_manager_info is None + or 'from_queryset_managers' not in base_manager_info.metadata): + return {} + return base_manager_info.metadata['from_queryset_managers'] + + def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance: + bases = [] + for original_base in base_manager_info.bases: + if self.is_any_parametrized_manager(original_base): + if original_base.type is None: + raise helpers.IncompleteDefnException() + + original_base = helpers.reparametrize_instance(original_base, + [Instance(self.model_classdef.info, [])]) + bases.append(original_base) + + current_module = self.api.modules[self.model_classdef.info.module_name] + custom_manager_info = helpers.add_new_class_for_module(current_module, + name=name, bases=bases) + # copy fields to a new manager + new_cls_def_context = ClassDefContext(cls=custom_manager_info.defn, + reason=self.ctx.reason, + api=self.api) + custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])]) + + for name, sym in base_manager_info.names.items(): + # replace self type with new class, if copying method + if isinstance(sym.node, FuncDef): + helpers.copy_method_to_another_class(new_cls_def_context, + self_type=custom_manager_type, + new_method_name=name, + method_node=sym.node) + continue + + new_sym = sym.copy() + if isinstance(new_sym.node, Var): + new_var = Var(name, type=sym.type) + new_var.info = custom_manager_info + new_var._fullname = custom_manager_info.fullname + '.' + name + new_sym.node = new_var + custom_manager_info.names[name] = new_sym + + return custom_manager_type def run_with_model_cls(self, model_cls: Type[Model]) -> None: for manager_name, manager in model_cls._meta.managers_map.items(): + manager_class_name = manager.__class__.__name__ manager_fullname = helpers.get_class_fullname(manager.__class__) try: manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) @@ -147,78 +202,33 @@ class AddManagers(ModelClassInitializer): if not self.api.final_iteration: raise exc else: - continue + base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) + generated_managers = self.get_generated_manager_mappings(base_manager_fullname) + if manager_fullname not in generated_managers: + # not a generated manager, continue with the loop + continue + real_manager_fullname = generated_managers[manager_fullname] + manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore + if manager_info is None: + continue + manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1] if manager_name not in self.model_classdef.info.names: manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) self.add_new_node_to_model_class(manager_name, manager_type) else: # creates new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model - has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) - if has_manager_any_base: - custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__ + if not self.has_any_parametrized_manager_as_base(manager_info): + continue - try: - bases = [] - for original_base in manager_info.bases: - if self._is_manager_any(original_base): - if original_base.type is None: - raise helpers.IncompleteDefnException() + custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name + try: + custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name, + base_manager_info=manager_info) + except helpers.IncompleteDefnException: + continue - original_base = helpers.reparametrize_instance(original_base, - [Instance(self.model_classdef.info, [])]) - bases.append(original_base) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - continue - - current_module = self.api.modules[self.model_classdef.info.module_name] - custom_manager_info = helpers.add_new_class_for_module(current_module, - custom_model_manager_name, - bases=bases, - fields=OrderedDict()) - # copy fields to a new manager - new_cls_def_context = ClassDefContext(cls=custom_manager_info.defn, - reason=self.ctx.reason, - api=self.api) - custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])]) - - for name, sym in manager_info.names.items(): - # replace self type with new class, if copying method - if isinstance(sym.node, FuncDef): - arguments, return_type = self.prepare_new_method_arguments(sym.node) - add_method(new_cls_def_context, - name, - args=arguments, - return_type=return_type, - self_type=custom_manager_type) - continue - - new_sym = sym.copy() - if isinstance(new_sym.node, Var): - new_var = Var(name, type=sym.type) - new_var.info = custom_manager_info - new_var._fullname = custom_manager_info.fullname + '.' + name - new_sym.node = new_var - custom_manager_info.names[name] = new_sym - - self.add_new_node_to_model_class(manager_name, custom_manager_type) - - def prepare_new_method_arguments(self, node: FuncDef) -> Tuple[List[Argument], MypyType]: - arguments = [] - for argument in node.arguments[1:]: - if argument.type_annotation is None: - argument.type_annotation = AnyType(TypeOfAny.unannotated) - arguments.append(argument) - - if isinstance(node.type, CallableType): - return_type = node.type.ret_type - else: - return_type = AnyType(TypeOfAny.unannotated) - - return arguments, return_type + self.add_new_node_to_model_class(manager_name, custom_manager_type) class AddDefaultManagerAttribute(ModelClassInitializer): diff --git a/test-data/typecheck/managers/querysets/test_from_queryset.yml b/test-data/typecheck/managers/querysets/test_from_queryset.yml new file mode 100644 index 0000000..136802a --- /dev/null +++ b/test-data/typecheck/managers/querysets/test_from_queryset.yml @@ -0,0 +1,52 @@ +- case: test_from_queryset_returns_intersection_of_manager_and_queryset + main: | + from myapp.models import MyModel, NewManager + reveal_type(NewManager()) # N: Revealed type is 'myapp.models.NewManager' + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'Any' + reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is 'builtins.int' + reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is 'builtins.str' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class ModelBaseManager(models.Manager): + def manager_only_method(self) -> int: + return 1 + class ModelQuerySet(models.QuerySet): + def manager_and_queryset_method(self) -> str: + return 'hello' + + NewManager = ModelBaseManager.from_queryset(ModelQuerySet) + class MyModel(models.Model): + objects = NewManager() + +- case: test_from_queryset_with_class_name_provided + main: | + from myapp.models import MyModel, NewManager + reveal_type(NewManager()) # N: Revealed type is 'myapp.models.NewManager' + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'Any' + reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is 'builtins.int' + reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is 'builtins.str' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class ModelBaseManager(models.Manager): + def manager_only_method(self) -> int: + return 1 + class ModelQuerySet(models.QuerySet): + def manager_and_queryset_method(self) -> str: + return 'hello' + + NewManager = ModelBaseManager.from_queryset(ModelQuerySet, class_name='NewManager') + class MyModel(models.Model): + objects = NewManager() +