Fix/673/from queryset then custom qs method (#680)

* Fix `MyModel.objects.filter(...).my_method()`

* Fix regression: `MyModel.objects.filter(...).my_method()` no longer worked when using from_queryset

This also fixes the self-type of the copied-over methods of the manager generated by from_queryset.
Previously it was not parameterized by the model class, but used Any.

The handling of unbound types is not tested here as I have not been able to
find a way to create a test case for it. It has been manually tested
against an internal codebase.

* Remove unneeded defer.
This commit is contained in:
Seth Yastrov
2021-07-30 00:01:39 +02:00
committed by GitHub
parent 08a662ecb1
commit 8da8ab4862
4 changed files with 129 additions and 33 deletions

View File

@@ -18,7 +18,6 @@ from mypy.nodes import (
MemberExpr, MemberExpr,
MypyFile, MypyFile,
NameExpr, NameExpr,
PlaceholderNode,
StrExpr, StrExpr,
SymbolNode, SymbolNode,
SymbolTable, SymbolTable,
@@ -33,12 +32,13 @@ from mypy.plugin import (
DynamicClassDefContext, DynamicClassDefContext,
FunctionContext, FunctionContext,
MethodContext, MethodContext,
SemanticAnalyzerPluginInterface,
) )
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
from mypy_django_plugin.lib import fullnames from mypy_django_plugin.lib import fullnames
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
return prepared_arguments, return_type return prepared_arguments, return_type
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
"""Analyze a type. If an unbound type, try to look it up in the given module name.
That should hopefully give a bound type."""
if isinstance(t, UnboundType) and module_name is not None:
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
if node is None:
return None
return node.type
else:
return api.anal_type(t)
def copy_method_to_another_class( def copy_method_to_another_class(
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef ctx: ClassDefContext,
self_type: Instance,
new_method_name: str,
method_node: FuncDef,
return_type: Optional[MypyType] = None,
original_module_name: Optional[str] = None,
) -> None: ) -> None:
semanal_api = get_semanal_api(ctx) semanal_api = get_semanal_api(ctx)
if method_node.type is None: if method_node.type is None:
@@ -374,23 +392,20 @@ def copy_method_to_another_class(
semanal_api.defer() semanal_api.defer()
return return
arguments = [] if return_type is None:
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True) return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
if return_type is None:
assert bound_return_type is not None
if isinstance(bound_return_type, PlaceholderNode):
return return
try: try:
original_arguments = method_node.arguments[1:] original_arguments = method_node.arguments[1:]
except AttributeError: except AttributeError:
original_arguments = [] original_arguments = []
arguments = []
for arg_name, arg_type, original_argument in zip( for arg_name, arg_type, original_argument in zip(
method_type.arg_names[1:], method_type.arg_types[1:], original_arguments method_type.arg_names[1:], method_type.arg_types[1:], original_arguments
): ):
bound_arg_type = semanal_api.anal_type(arg_type) bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
if bound_arg_type is None: if bound_arg_type is None:
return return
@@ -406,4 +421,10 @@ def copy_method_to_another_class(
argument.set_line(original_argument) argument.set_line(original_argument)
arguments.append(argument) arguments.append(argument)
add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type) add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
get_django_metadata(sym.node)["manager_bases"][fullname] = 1

View File

@@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None:
forms.make_meta_nested_class_inherit_from_any(ctx) forms.make_meta_nested_class_inherit_from_any(ctx)
def add_new_manager_base(ctx: ClassDefContext) -> None: def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
def extract_django_settings_module(config_file_path: Optional[str]) -> str: def extract_django_settings_module(config_file_path: Optional[str]) -> str:
@@ -235,7 +233,12 @@ class NewSemanalDjangoPlugin(Plugin):
related_model_module = related_model_cls.__module__ related_model_module = related_model_cls.__module__
if related_model_module != file.fullname: if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module)) deps.add(self._new_dependency(related_model_module))
return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
# For BaseManager.from_queryset
self._new_dependency("django.db.models.query"),
]
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]: def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
if fullname == "django.contrib.auth.get_user_model": if fullname == "django.contrib.auth.get_user_model":
@@ -305,7 +308,7 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(transform_model_class, django_context=self.django_context) return partial(transform_model_class, django_context=self.django_context)
if fullname in self._get_current_manager_bases(): if fullname in self._get_current_manager_bases():
return add_new_manager_base return add_new_manager_base_hook
if fullname in self._get_current_form_bases(): if fullname in self._get_current_form_bases():
return transform_form_class return transform_form_class

View File

@@ -1,6 +1,7 @@
from mypy.checker import fill_typevars
from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
from mypy.plugin import ClassDefContext, DynamicClassDefContext from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.types import AnyType, Instance, TypeOfAny from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
@@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
# But it should be analyzed again, so this isn't a problem. # But it should be analyzed again, so this isn't a problem.
return return
base_manager_instance = fill_typevars(base_manager_info)
assert isinstance(base_manager_instance, Instance)
new_manager_info = semanal_api.basic_new_typeinfo( new_manager_info = semanal_api.basic_new_typeinfo(
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
) )
new_manager_info.line = ctx.call.line
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None assert sym is not None
@@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
derived_queryset_info = sym.node derived_queryset_info = sym.node
assert isinstance(derived_queryset_info, TypeInfo) assert isinstance(derived_queryset_info, TypeInfo)
new_manager_info.line = ctx.call.line
new_manager_info.type_vars = base_manager_info.type_vars
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
if len(ctx.call.args) > 1: if len(ctx.call.args) > 1:
expr = ctx.call.args[1] expr = ctx.call.args[1]
assert isinstance(expr, StrExpr) assert isinstance(expr, StrExpr)
@@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
base_manager_info.metadata["from_queryset_managers"] = {} base_manager_info.metadata["from_queryset_managers"] = {}
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)
class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api) class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, []) self_type = fill_typevars(new_manager_info)
assert isinstance(self_type, Instance)
queryset_method_names = []
# we need to copy all methods in MRO before django.db.models.query.QuerySet # we need to copy all methods in MRO before django.db.models.query.QuerySet
for class_mro_info in derived_queryset_info.mro: for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
for name, sym in class_mro_info.names.items():
queryset_method_names.append(name)
break break
for name, sym in class_mro_info.names.items(): for name, sym in class_mro_info.names.items():
if isinstance(sym.node, FuncDef): if isinstance(sym.node, FuncDef):
@@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
helpers.copy_method_to_another_class( helpers.copy_method_to_another_class(
class_def_context, self_type, new_method_name=name, method_node=func_node class_def_context, self_type, new_method_name=name, method_node=func_node
) )
# Gather names of all BaseManager methods
manager_method_names = []
for manager_mro_info in new_manager_info.mro:
if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
for name, sym in manager_mro_info.names.items():
manager_method_names.append(name)
# Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is
# the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model
# type variable.
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME:
continue
for name, sym in class_mro_info.names.items():
if name not in manager_method_names:
continue
if isinstance(sym.node, FuncDef):
func_node = sym.node
elif isinstance(sym.node, Decorator):
func_node = sym.node.func
else:
continue
method_type = func_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return None
original_return_type = method_type.ret_type
if original_return_type is None:
continue
# Skip any method that doesn't return _QS
original_return_type = get_proper_type(original_return_type)
if isinstance(original_return_type, UnboundType):
if original_return_type.name != "_QS":
continue
elif isinstance(original_return_type, TypeVarType):
if original_return_type.name != "_QS":
continue
else:
continue
# Return the custom queryset parameterized by the manager's type vars
return_type = Instance(derived_queryset_info, self_type.args)
helpers.copy_method_to_another_class(
class_def_context,
self_type,
new_method_name=name,
method_node=func_node,
return_type=return_type,
original_module_name=class_mro_info.module_name,
)

View File

@@ -1,9 +1,11 @@
- case: from_queryset_with_base_manager - case: from_queryset_with_base_manager
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]"
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -23,7 +25,7 @@
- case: from_queryset_with_manager - case: from_queryset_with_manager
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps: installed_apps:
@@ -97,7 +99,7 @@
- case: from_queryset_with_class_inheritance - case: from_queryset_with_class_inheritance
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps: installed_apps:
@@ -121,7 +123,7 @@
- case: from_queryset_with_manager_in_another_directory_and_imports - case: from_queryset_with_manager_in_another_directory_and_imports
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]" reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]"
reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]" reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]"
@@ -151,7 +153,7 @@
disable_cache: true disable_cache: true
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>" reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>"
reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>" reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>"
@@ -183,7 +185,7 @@
- case: from_queryset_with_decorated_queryset_methods - case: from_queryset_with_decorated_queryset_methods
main: | main: |
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps: installed_apps:
- myapp - myapp