mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-11 06:21:58 +08:00
Fix/673/from queryset then custom qs method (#680)
* Fix `MyModel.objects.filter(...).my_method()` * Fix regression: `MyModel.objects.filter(...).my_method()` no longer worked when using from_queryset This also fixes the self-type of the copied-over methods of the manager generated by from_queryset. Previously it was not parameterized by the model class, but used Any. The handling of unbound types is not tested here as I have not been able to find a way to create a test case for it. It has been manually tested against an internal codebase. * Remove unneeded defer.
This commit is contained in:
@@ -18,7 +18,6 @@ from mypy.nodes import (
|
||||
MemberExpr,
|
||||
MypyFile,
|
||||
NameExpr,
|
||||
PlaceholderNode,
|
||||
StrExpr,
|
||||
SymbolNode,
|
||||
SymbolTable,
|
||||
@@ -33,12 +32,13 @@ from mypy.plugin import (
|
||||
DynamicClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
SemanticAnalyzerPluginInterface,
|
||||
)
|
||||
from mypy.plugins.common import add_method
|
||||
from mypy.semanal import SemanticAnalyzer
|
||||
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypedDictType, TypeOfAny, UnionType
|
||||
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
|
||||
|
||||
from mypy_django_plugin.lib import fullnames
|
||||
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
|
||||
@@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
|
||||
return prepared_arguments, return_type
|
||||
|
||||
|
||||
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
|
||||
"""Analyze a type. If an unbound type, try to look it up in the given module name.
|
||||
|
||||
That should hopefully give a bound type."""
|
||||
if isinstance(t, UnboundType) and module_name is not None:
|
||||
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
|
||||
if node is None:
|
||||
return None
|
||||
return node.type
|
||||
else:
|
||||
return api.anal_type(t)
|
||||
|
||||
|
||||
def copy_method_to_another_class(
|
||||
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
|
||||
ctx: ClassDefContext,
|
||||
self_type: Instance,
|
||||
new_method_name: str,
|
||||
method_node: FuncDef,
|
||||
return_type: Optional[MypyType] = None,
|
||||
original_module_name: Optional[str] = None,
|
||||
) -> None:
|
||||
semanal_api = get_semanal_api(ctx)
|
||||
if method_node.type is None:
|
||||
@@ -374,23 +392,20 @@ def copy_method_to_another_class(
|
||||
semanal_api.defer()
|
||||
return
|
||||
|
||||
arguments = []
|
||||
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)
|
||||
|
||||
assert bound_return_type is not None
|
||||
|
||||
if isinstance(bound_return_type, PlaceholderNode):
|
||||
if return_type is None:
|
||||
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
|
||||
if return_type is None:
|
||||
return
|
||||
|
||||
try:
|
||||
original_arguments = method_node.arguments[1:]
|
||||
except AttributeError:
|
||||
original_arguments = []
|
||||
|
||||
arguments = []
|
||||
for arg_name, arg_type, original_argument in zip(
|
||||
method_type.arg_names[1:], method_type.arg_types[1:], original_arguments
|
||||
):
|
||||
bound_arg_type = semanal_api.anal_type(arg_type)
|
||||
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
|
||||
if bound_arg_type is None:
|
||||
return
|
||||
|
||||
@@ -406,4 +421,10 @@ def copy_method_to_another_class(
|
||||
argument.set_line(original_argument)
|
||||
arguments.append(argument)
|
||||
|
||||
add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
|
||||
add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
|
||||
|
||||
|
||||
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
|
||||
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
get_django_metadata(sym.node)["manager_bases"][fullname] = 1
|
||||
|
||||
@@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None:
|
||||
forms.make_meta_nested_class_inherit_from_any(ctx)
|
||||
|
||||
|
||||
def add_new_manager_base(ctx: ClassDefContext) -> None:
|
||||
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
|
||||
def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
|
||||
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)
|
||||
|
||||
|
||||
def extract_django_settings_module(config_file_path: Optional[str]) -> str:
|
||||
@@ -235,7 +233,12 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
related_model_module = related_model_cls.__module__
|
||||
if related_model_module != file.fullname:
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate
|
||||
return list(deps) + [
|
||||
# for QuerySet.annotate
|
||||
self._new_dependency("django_stubs_ext"),
|
||||
# For BaseManager.from_queryset
|
||||
self._new_dependency("django.db.models.query"),
|
||||
]
|
||||
|
||||
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
|
||||
if fullname == "django.contrib.auth.get_user_model":
|
||||
@@ -305,7 +308,7 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
return partial(transform_model_class, django_context=self.django_context)
|
||||
|
||||
if fullname in self._get_current_manager_bases():
|
||||
return add_new_manager_base
|
||||
return add_new_manager_base_hook
|
||||
|
||||
if fullname in self._get_current_form_bases():
|
||||
return transform_form_class
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from mypy.checker import fill_typevars
|
||||
from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
|
||||
from mypy.plugin import ClassDefContext, DynamicClassDefContext
|
||||
from mypy.types import AnyType, Instance, TypeOfAny
|
||||
from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type
|
||||
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
@@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
# But it should be analyzed again, so this isn't a problem.
|
||||
return
|
||||
|
||||
base_manager_instance = fill_typevars(base_manager_info)
|
||||
assert isinstance(base_manager_instance, Instance)
|
||||
new_manager_info = semanal_api.basic_new_typeinfo(
|
||||
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line
|
||||
ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line
|
||||
)
|
||||
new_manager_info.line = ctx.call.line
|
||||
new_manager_info.defn.line = ctx.call.line
|
||||
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
|
||||
|
||||
current_module = semanal_api.cur_mod_node
|
||||
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
|
||||
|
||||
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
|
||||
assert sym is not None
|
||||
@@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
derived_queryset_info = sym.node
|
||||
assert isinstance(derived_queryset_info, TypeInfo)
|
||||
|
||||
new_manager_info.line = ctx.call.line
|
||||
new_manager_info.type_vars = base_manager_info.type_vars
|
||||
new_manager_info.defn.type_vars = base_manager_info.defn.type_vars
|
||||
new_manager_info.defn.line = ctx.call.line
|
||||
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
|
||||
|
||||
current_module = semanal_api.cur_mod_node
|
||||
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
|
||||
|
||||
if len(ctx.call.args) > 1:
|
||||
expr = ctx.call.args[1]
|
||||
assert isinstance(expr, StrExpr)
|
||||
@@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
base_manager_info.metadata["from_queryset_managers"] = {}
|
||||
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
|
||||
|
||||
# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
|
||||
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)
|
||||
|
||||
class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
|
||||
self_type = Instance(new_manager_info, [])
|
||||
self_type = fill_typevars(new_manager_info)
|
||||
assert isinstance(self_type, Instance)
|
||||
queryset_method_names = []
|
||||
|
||||
# we need to copy all methods in MRO before django.db.models.query.QuerySet
|
||||
for class_mro_info in derived_queryset_info.mro:
|
||||
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
|
||||
for name, sym in class_mro_info.names.items():
|
||||
queryset_method_names.append(name)
|
||||
break
|
||||
for name, sym in class_mro_info.names.items():
|
||||
if isinstance(sym.node, FuncDef):
|
||||
@@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
helpers.copy_method_to_another_class(
|
||||
class_def_context, self_type, new_method_name=name, method_node=func_node
|
||||
)
|
||||
|
||||
# Gather names of all BaseManager methods
|
||||
manager_method_names = []
|
||||
for manager_mro_info in new_manager_info.mro:
|
||||
if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
|
||||
for name, sym in manager_mro_info.names.items():
|
||||
manager_method_names.append(name)
|
||||
|
||||
# Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is
|
||||
# the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model
|
||||
# type variable.
|
||||
for class_mro_info in derived_queryset_info.mro:
|
||||
if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME:
|
||||
continue
|
||||
for name, sym in class_mro_info.names.items():
|
||||
if name not in manager_method_names:
|
||||
continue
|
||||
|
||||
if isinstance(sym.node, FuncDef):
|
||||
func_node = sym.node
|
||||
elif isinstance(sym.node, Decorator):
|
||||
func_node = sym.node.func
|
||||
else:
|
||||
continue
|
||||
|
||||
method_type = func_node.type
|
||||
if not isinstance(method_type, CallableType):
|
||||
if not semanal_api.final_iteration:
|
||||
semanal_api.defer()
|
||||
return None
|
||||
original_return_type = method_type.ret_type
|
||||
if original_return_type is None:
|
||||
continue
|
||||
|
||||
# Skip any method that doesn't return _QS
|
||||
original_return_type = get_proper_type(original_return_type)
|
||||
if isinstance(original_return_type, UnboundType):
|
||||
if original_return_type.name != "_QS":
|
||||
continue
|
||||
elif isinstance(original_return_type, TypeVarType):
|
||||
if original_return_type.name != "_QS":
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# Return the custom queryset parameterized by the manager's type vars
|
||||
return_type = Instance(derived_queryset_info, self_type.args)
|
||||
|
||||
helpers.copy_method_to_another_class(
|
||||
class_def_context,
|
||||
self_type,
|
||||
new_method_name=name,
|
||||
method_node=func_node,
|
||||
return_type=return_type,
|
||||
original_module_name=class_mro_info.module_name,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
- case: from_queryset_with_base_manager
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
|
||||
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
|
||||
reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str"
|
||||
reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]"
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -23,7 +25,7 @@
|
||||
- case: from_queryset_with_manager
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
|
||||
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
|
||||
installed_apps:
|
||||
@@ -97,7 +99,7 @@
|
||||
- case: from_queryset_with_class_inheritance
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
|
||||
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
|
||||
installed_apps:
|
||||
@@ -121,7 +123,7 @@
|
||||
- case: from_queryset_with_manager_in_another_directory_and_imports
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
|
||||
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]"
|
||||
reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]"
|
||||
@@ -151,7 +153,7 @@
|
||||
disable_cache: true
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*"
|
||||
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> <nothing>"
|
||||
reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "<nothing>"
|
||||
@@ -183,7 +185,7 @@
|
||||
- case: from_queryset_with_decorated_queryset_methods
|
||||
main: |
|
||||
from myapp.models import MyModel
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
|
||||
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str"
|
||||
installed_apps:
|
||||
- myapp
|
||||
|
||||
Reference in New Issue
Block a user