Fix/673/from queryset then custom qs method (#680)

* Fix `MyModel.objects.filter(...).my_method()`

* Fix regression: `MyModel.objects.filter(...).my_method()` no longer worked when using from_queryset

This also fixes the self-type of the copied-over methods of the manager generated by from_queryset.
Previously it was not parameterized by the model class, but used Any.

The handling of unbound types is not tested here as I have not been able to
find a way to create a test case for it. It has been manually tested
against an internal codebase.

* Remove unneeded defer.
This commit is contained in:
Seth Yastrov
2021-07-30 00:01:39 +02:00
committed by GitHub
parent 08a662ecb1
commit 8da8ab4862
4 changed files with 129 additions and 33 deletions

View File

@@ -18,7 +18,6 @@ from mypy.nodes import (
MemberExpr,
MypyFile,
NameExpr,
PlaceholderNode,
StrExpr,
SymbolNode,
SymbolTable,
@@ -33,12 +32,13 @@ from mypy.plugin import (
DynamicClassDefContext,
FunctionContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
from mypy_django_plugin.lib import fullnames
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
return prepared_arguments, return_type
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
"""Analyze a type. If an unbound type, try to look it up in the given module name.
That should hopefully give a bound type."""
if isinstance(t, UnboundType) and module_name is not None:
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
if node is None:
return None
return node.type
else:
return api.anal_type(t)
def copy_method_to_another_class(
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
ctx: ClassDefContext,
self_type: Instance,
new_method_name: str,
method_node: FuncDef,
return_type: Optional[MypyType] = None,
original_module_name: Optional[str] = None,
) -> None:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
@@ -374,23 +392,20 @@ def copy_method_to_another_class(
semanal_api.defer()
return
arguments = []
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)
assert bound_return_type is not None
if isinstance(bound_return_type, PlaceholderNode):
if return_type is None:
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
if return_type is None:
return
try:
original_arguments = method_node.arguments[1:]
except AttributeError:
original_arguments = []
arguments = []
for arg_name, arg_type, original_argument in zip(
method_type.arg_names[1:], method_type.arg_types[1:], original_arguments
):
bound_arg_type = semanal_api.anal_type(arg_type)
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
if bound_arg_type is None:
return
@@ -406,4 +421,10 @@ def copy_method_to_another_class(
argument.set_line(original_argument)
arguments.append(argument)
add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
get_django_metadata(sym.node)["manager_bases"][fullname] = 1

View File

@@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None:
forms.make_meta_nested_class_inherit_from_any(ctx)
def add_new_manager_base(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)
def extract_django_settings_module(config_file_path: Optional[str]) -> str:
@@ -235,7 +233,12 @@ class NewSemanalDjangoPlugin(Plugin):
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))
return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate
return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
# For BaseManager.from_queryset
self._new_dependency("django.db.models.query"),
]
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
if fullname == "django.contrib.auth.get_user_model":
@@ -305,7 +308,7 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(transform_model_class, django_context=self.django_context)
if fullname in self._get_current_manager_bases():
return add_new_manager_base
return add_new_manager_base_hook
if fullname in self._get_current_form_bases():
return transform_form_class

View File

@@ -1,6 +1,7 @@
from mypy.checker import fill_typevars
from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.types import AnyType, Instance, TypeOfAny
from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type
from mypy_django_plugin.lib import fullnames, helpers
@@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
# But it should be analyzed again, so this isn't a problem.
return
base_manager_instance = fill_typevars(base_manager_info)
assert isinstance(base_manager_instance, Instance)
new_manager_info = semanal_api.basic_new_typeinfo(
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line
ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
)
new_manager_info.line = ctx.call.line
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None
@@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
derived_queryset_info = sym.node
assert isinstance(derived_queryset_info, TypeInfo)
new_manager_info.line = ctx.call.line
new_manager_info.type_vars = base_manager_info.type_vars
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
if len(ctx.call.args) > 1:
expr = ctx.call.args[1]
assert isinstance(expr, StrExpr)
@@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
base_manager_info.metadata["from_queryset_managers"] = {}
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)
class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, [])
self_type = fill_typevars(new_manager_info)
assert isinstance(self_type, Instance)
queryset_method_names = []
# we need to copy all methods in MRO before django.db.models.query.QuerySet
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
for name, sym in class_mro_info.names.items():
queryset_method_names.append(name)
break
for name, sym in class_mro_info.names.items():
if isinstance(sym.node, FuncDef):
@@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
helpers.copy_method_to_another_class(
class_def_context, self_type, new_method_name=name, method_node=func_node
)
# Gather names of all BaseManager methods
manager_method_names = []
for manager_mro_info in new_manager_info.mro:
if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
for name, sym in manager_mro_info.names.items():
manager_method_names.append(name)
# Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is
# the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model
# type variable.
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME:
continue
for name, sym in class_mro_info.names.items():
if name not in manager_method_names:
continue
if isinstance(sym.node, FuncDef):
func_node = sym.node
elif isinstance(sym.node, Decorator):
func_node = sym.node.func
else:
continue
method_type = func_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return None
original_return_type = method_type.ret_type
if original_return_type is None:
continue
# Skip any method that doesn't return _QS
original_return_type = get_proper_type(original_return_type)
if isinstance(original_return_type, UnboundType):
if original_return_type.name != "_QS":
continue
elif isinstance(original_return_type, TypeVarType):
if original_return_type.name != "_QS":
continue
else:
continue
# Return the custom queryset parameterized by the manager's type vars
return_type = Instance(derived_queryset_info, self_type.args)
helpers.copy_method_to_another_class(
class_def_context,
self_type,
new_method_name=name,
method_node=func_node,
return_type=return_type,
original_module_name=class_mro_info.module_name,
)

View File

@@ -1,9 +1,11 @@
- case: from_queryset_with_base_manager
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str"
reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]"
installed_apps:
- myapp
files:
@@ -23,7 +25,7 @@
- case: from_queryset_with_manager
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
@@ -97,7 +99,7 @@
- case: from_queryset_with_class_inheritance
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
@@ -121,7 +123,7 @@
- case: from_queryset_with_manager_in_another_directory_and_imports
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]"
reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]"
@@ -151,7 +153,7 @@
disable_cache: true
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>"
reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>"
@@ -183,7 +185,7 @@
- case: from_queryset_with_decorated_queryset_methods
main: |
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
installed_apps:
- myapp