Add support for BaseManager.from_queryset() (#251)

* add support for BaseManager.from_queryset()

* cleanups

* lint fixes
This commit is contained in:
Maksim Kurnikov
2019-12-12 05:35:56 +03:00
committed by GitHub
parent b8f29027d8
commit ade48b6546
5 changed files with 273 additions and 91 deletions

View File

@@ -1,6 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from typing import ( from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, cast, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
) )
from django.db.models.fields import Field from django.db.models.fields import Field
@@ -10,13 +10,15 @@ from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
SymbolTableNode, TypeInfo, Var, SymbolTable, SymbolTableNode, TypeInfo, Var,
) )
from mypy.plugin import ( from mypy.plugin import (
AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext,
) )
from mypy.types import AnyType, Instance, NoneTyp, TupleType from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType from mypy.types import TypedDictType, TypeOfAny, UnionType
@@ -55,7 +57,7 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
return sym.node return sym.node
def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]: def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules) node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo): if not isinstance(node, TypeInfo):
return None return None
@@ -173,8 +175,11 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
return None return None
def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], def add_new_class_for_module(module: MypyFile,
fields: 'OrderedDict[str, MypyType]') -> TypeInfo: name: str,
bases: List[Instance],
fields: Optional[Dict[str, MypyType]] = None
) -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, module.names) new_class_unique_name = checker.gen_unique_name(name, module.names)
# make new class expression # make new class expression
@@ -188,11 +193,12 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
new_typeinfo.calculate_metaclass_type() new_typeinfo.calculate_metaclass_type()
# add fields # add fields
for field_name, field_type in fields.items(): if fields:
var = Var(field_name, type=field_type) for field_name, field_type in fields.items():
var.info = new_typeinfo var = Var(field_name, type=field_type)
var._fullname = new_typeinfo.fullname + '.' + field_name var.info = new_typeinfo
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) var._fullname = new_typeinfo.fullname + '.' + field_name
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
classdef.info = new_typeinfo classdef.info = new_typeinfo
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
@@ -269,10 +275,16 @@ def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionCon
return None return None
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
if not isinstance(ctx.api, SemanticAnalyzer):
raise ValueError('Not a SemanticAnalyzer')
return ctx.api
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker): if not isinstance(ctx.api, TypeChecker):
raise ValueError('Not a TypeChecker') raise ValueError('Not a TypeChecker')
return cast(TypeChecker, ctx.api) return ctx.api
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
@@ -298,3 +310,28 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> No
var.is_inferred = True var.is_inferred = True
info.names[name] = SymbolTableNode(MDEF, var, info.names[name] = SymbolTableNode(MDEF, var,
plugin_generated=True) plugin_generated=True)
def _prepare_new_method_arguments(node: FuncDef) -> Tuple[List[Argument], MypyType]:
arguments = []
for argument in node.arguments[1:]:
if argument.type_annotation is None:
argument.type_annotation = AnyType(TypeOfAny.unannotated)
arguments.append(argument)
if isinstance(node.type, CallableType):
return_type = node.type.ret_type
else:
return_type = AnyType(TypeOfAny.unannotated)
return arguments, return_type
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
new_method_name: str, method_node: FuncDef) -> None:
arguments, return_type = _prepare_new_method_arguments(method_node)
add_method(ctx,
new_method_name,
args=arguments,
return_type=return_type,
self_type=self_type)

View File

@@ -7,7 +7,7 @@ from mypy.errors import Errors
from mypy.nodes import MypyFile, TypeInfo from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options from mypy.options import Options
from mypy.plugin import ( from mypy.plugin import (
AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin,
) )
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
@@ -17,6 +17,9 @@ from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import ( from mypy_django_plugin.transformers import (
fields, forms, init_create, meta, querysets, request, settings, fields, forms, init_create, meta, querysets, request, settings,
) )
from mypy_django_plugin.transformers.managers import (
create_new_manager_class_from_from_queryset_method,
)
from mypy_django_plugin.transformers.models import process_model_class from mypy_django_plugin.transformers.models import process_model_class
@@ -242,6 +245,15 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)
return None return None
def get_dynamic_class_hook(self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
if fullname.endswith('from_queryset'):
class_name, _, _ = fullname.rpartition('.')
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return create_new_manager_class_from_from_queryset_method
return None
def plugin(version): def plugin(version):
return NewSemanalDjangoPlugin return NewSemanalDjangoPlugin

View File

@@ -0,0 +1,71 @@
from mypy.nodes import (
GDEF, FuncDef, MemberExpr, NameExpr, StrExpr, SymbolTableNode, TypeInfo,
)
from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.types import AnyType, Instance, TypeOfAny
from mypy_django_plugin.lib import helpers
def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None:
semanal_api = helpers.get_semanal_api(ctx)
assert isinstance(ctx.call.callee, MemberExpr)
assert isinstance(ctx.call.callee.expr, NameExpr)
base_manager_info = ctx.call.callee.expr.node
if base_manager_info is None:
if not semanal_api.final_iteration:
semanal_api.defer()
return
assert isinstance(base_manager_info, TypeInfo)
new_manager_info = semanal_api.basic_new_typeinfo(ctx.name,
basetype_or_fallback=Instance(base_manager_info,
[AnyType(TypeOfAny.unannotated)]))
new_manager_info.line = ctx.call.line
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info,
plugin_generated=True)
passed_queryset = ctx.call.args[0]
assert isinstance(passed_queryset, NameExpr)
derived_queryset_fullname = passed_queryset.fullname
assert derived_queryset_fullname is not None
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None
if sym.node is None:
if not semanal_api.final_iteration:
semanal_api.defer()
else:
# inherit from Any to prevent false-positives, if queryset class cannot be resolved
new_manager_info.fallback_to_any = True
return
derived_queryset_info = sym.node
assert isinstance(derived_queryset_info, TypeInfo)
if len(ctx.call.args) > 1:
expr = ctx.call.args[1]
assert isinstance(expr, StrExpr)
custom_manager_generated_name = expr.value
else:
custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name
custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name])
if 'from_queryset_managers' not in base_manager_info.metadata:
base_manager_info.metadata['from_queryset_managers'] = {}
base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname
class_def_context = ClassDefContext(cls=new_manager_info.defn,
reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, [])
for name, sym in derived_queryset_info.names.items():
if isinstance(sym.node, FuncDef):
helpers.copy_method_to_another_class(class_def_context,
self_type,
new_method_name=name,
method_node=sym.node)

View File

@@ -1,5 +1,4 @@
from collections import OrderedDict from typing import Dict, Optional, Type, cast
from typing import List, Tuple, Type
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields import DateField, DateTimeField from django.db.models.fields import DateField, DateTimeField
@@ -8,10 +7,10 @@ from django.db.models.fields.reverse_related import (
ManyToManyRel, ManyToOneRel, OneToOneRel, ManyToManyRel, ManyToOneRel, OneToOneRel,
) )
from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugin import ClassDefContext
from mypy.plugins import common from mypy.plugins import common
from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny
@@ -22,19 +21,22 @@ from mypy_django_plugin.transformers.fields import get_field_descriptor_types
class ModelClassInitializer: class ModelClassInitializer:
api: SemanticAnalyzerPluginInterface api: SemanticAnalyzer
def __init__(self, ctx: ClassDefContext, django_context: DjangoContext): def __init__(self, ctx: ClassDefContext, django_context: DjangoContext):
self.api = ctx.api self.api = cast(SemanticAnalyzer, ctx.api)
self.model_classdef = ctx.cls self.model_classdef = ctx.cls
self.django_context = django_context self.django_context = django_context
self.ctx = ctx self.ctx = ctx
def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]:
return helpers.lookup_fully_qualified_typeinfo(self.api, fullname)
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo: def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
sym = self.api.lookup_fully_qualified_or_none(fullname) info = self.lookup_typeinfo(fullname)
if sym is None or not isinstance(sym.node, TypeInfo): if info is None:
raise helpers.IncompleteDefnException(f'No {fullname!r} found') raise helpers.IncompleteDefnException(f'No {fullname!r} found')
return sym.node return info
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo: def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
fullname = helpers.get_class_fullname(klass) fullname = helpers.get_class_fullname(klass)
@@ -135,11 +137,64 @@ class AddRelatedModelsId(ModelClassInitializer):
class AddManagers(ModelClassInitializer): class AddManagers(ModelClassInitializer):
def _is_manager_any(self, typ: Instance) -> bool: def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool:
return typ.type.fullname == fullnames.MANAGER_CLASS_FULLNAME and type(typ.args[0]) == AnyType for base in helpers.iter_bases(info):
if self.is_any_parametrized_manager(base):
return True
return False
def is_any_parametrized_manager(self, typ: Instance) -> bool:
return typ.type.fullname == fullnames.MANAGER_CLASS_FULLNAME and isinstance(typ.args[0], AnyType)
def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]:
base_manager_info = self.lookup_typeinfo(base_manager_fullname)
if (base_manager_info is None
or 'from_queryset_managers' not in base_manager_info.metadata):
return {}
return base_manager_info.metadata['from_queryset_managers']
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
bases = []
for original_base in base_manager_info.bases:
if self.is_any_parametrized_manager(original_base):
if original_base.type is None:
raise helpers.IncompleteDefnException()
original_base = helpers.reparametrize_instance(original_base,
[Instance(self.model_classdef.info, [])])
bases.append(original_base)
current_module = self.api.modules[self.model_classdef.info.module_name]
custom_manager_info = helpers.add_new_class_for_module(current_module,
name=name, bases=bases)
# copy fields to a new manager
new_cls_def_context = ClassDefContext(cls=custom_manager_info.defn,
reason=self.ctx.reason,
api=self.api)
custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])])
for name, sym in base_manager_info.names.items():
# replace self type with new class, if copying method
if isinstance(sym.node, FuncDef):
helpers.copy_method_to_another_class(new_cls_def_context,
self_type=custom_manager_type,
new_method_name=name,
method_node=sym.node)
continue
new_sym = sym.copy()
if isinstance(new_sym.node, Var):
new_var = Var(name, type=sym.type)
new_var.info = custom_manager_info
new_var._fullname = custom_manager_info.fullname + '.' + name
new_sym.node = new_var
custom_manager_info.names[name] = new_sym
return custom_manager_type
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for manager_name, manager in model_cls._meta.managers_map.items(): for manager_name, manager in model_cls._meta.managers_map.items():
manager_class_name = manager.__class__.__name__
manager_fullname = helpers.get_class_fullname(manager.__class__) manager_fullname = helpers.get_class_fullname(manager.__class__)
try: try:
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
@@ -147,78 +202,33 @@ class AddManagers(ModelClassInitializer):
if not self.api.final_iteration: if not self.api.final_iteration:
raise exc raise exc
else: else:
continue base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0])
generated_managers = self.get_generated_manager_mappings(base_manager_fullname)
if manager_fullname not in generated_managers:
# not a generated manager, continue with the loop
continue
real_manager_fullname = generated_managers[manager_fullname]
manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore
if manager_info is None:
continue
manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1]
if manager_name not in self.model_classdef.info.names: if manager_name not in self.model_classdef.info.names:
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, manager_type) self.add_new_node_to_model_class(manager_name, manager_type)
else: else:
# creates new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model # creates new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model
has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) if not self.has_any_parametrized_manager_as_base(manager_info):
if has_manager_any_base: continue
custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__
try: custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name
bases = [] try:
for original_base in manager_info.bases: custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name,
if self._is_manager_any(original_base): base_manager_info=manager_info)
if original_base.type is None: except helpers.IncompleteDefnException:
raise helpers.IncompleteDefnException() continue
original_base = helpers.reparametrize_instance(original_base, self.add_new_node_to_model_class(manager_name, custom_manager_type)
[Instance(self.model_classdef.info, [])])
bases.append(original_base)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue
current_module = self.api.modules[self.model_classdef.info.module_name]
custom_manager_info = helpers.add_new_class_for_module(current_module,
custom_model_manager_name,
bases=bases,
fields=OrderedDict())
# copy fields to a new manager
new_cls_def_context = ClassDefContext(cls=custom_manager_info.defn,
reason=self.ctx.reason,
api=self.api)
custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])])
for name, sym in manager_info.names.items():
# replace self type with new class, if copying method
if isinstance(sym.node, FuncDef):
arguments, return_type = self.prepare_new_method_arguments(sym.node)
add_method(new_cls_def_context,
name,
args=arguments,
return_type=return_type,
self_type=custom_manager_type)
continue
new_sym = sym.copy()
if isinstance(new_sym.node, Var):
new_var = Var(name, type=sym.type)
new_var.info = custom_manager_info
new_var._fullname = custom_manager_info.fullname + '.' + name
new_sym.node = new_var
custom_manager_info.names[name] = new_sym
self.add_new_node_to_model_class(manager_name, custom_manager_type)
def prepare_new_method_arguments(self, node: FuncDef) -> Tuple[List[Argument], MypyType]:
arguments = []
for argument in node.arguments[1:]:
if argument.type_annotation is None:
argument.type_annotation = AnyType(TypeOfAny.unannotated)
arguments.append(argument)
if isinstance(node.type, CallableType):
return_type = node.type.ret_type
else:
return_type = AnyType(TypeOfAny.unannotated)
return arguments, return_type
class AddDefaultManagerAttribute(ModelClassInitializer): class AddDefaultManagerAttribute(ModelClassInitializer):

View File

@@ -0,0 +1,52 @@
- case: test_from_queryset_returns_intersection_of_manager_and_queryset
main: |
from myapp.models import MyModel, NewManager
reveal_type(NewManager()) # N: Revealed type is 'myapp.models.NewManager'
reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
reveal_type(MyModel.objects.get()) # N: Revealed type is 'Any'
reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is 'builtins.int'
reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is 'builtins.str'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class ModelBaseManager(models.Manager):
def manager_only_method(self) -> int:
return 1
class ModelQuerySet(models.QuerySet):
def manager_and_queryset_method(self) -> str:
return 'hello'
NewManager = ModelBaseManager.from_queryset(ModelQuerySet)
class MyModel(models.Model):
objects = NewManager()
- case: test_from_queryset_with_class_name_provided
main: |
from myapp.models import MyModel, NewManager
reveal_type(NewManager()) # N: Revealed type is 'myapp.models.NewManager'
reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
reveal_type(MyModel.objects.get()) # N: Revealed type is 'Any'
reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is 'builtins.int'
reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is 'builtins.str'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class ModelBaseManager(models.Manager):
def manager_only_method(self) -> int:
return 1
class ModelQuerySet(models.QuerySet):
def manager_and_queryset_method(self) -> str:
return 'hello'
NewManager = ModelBaseManager.from_queryset(ModelQuerySet, class_name='NewManager')
class MyModel(models.Model):
objects = NewManager()