diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 198916d..70c8471 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -18,7 +18,6 @@ from mypy.nodes import ( MemberExpr, MypyFile, NameExpr, - PlaceholderNode, StrExpr, SymbolNode, SymbolTable, @@ -33,12 +32,13 @@ from mypy.plugin import ( DynamicClassDefContext, FunctionContext, MethodContext, + SemanticAnalyzerPluginInterface, ) from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzer from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType from mypy_django_plugin.lib import fullnames from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME @@ -355,8 +355,26 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], return prepared_arguments, return_type +def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]: + """Analyze a type. If an unbound type, try to look it up in the given module name. + + That should hopefully give a bound type.""" + if isinstance(t, UnboundType) and module_name is not None: + node = api.lookup_fully_qualified_or_none(module_name + "." + t.name) + if node is None: + return None + return node.type + else: + return api.anal_type(t) + + def copy_method_to_another_class( - ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef + ctx: ClassDefContext, + self_type: Instance, + new_method_name: str, + method_node: FuncDef, + return_type: Optional[MypyType] = None, + original_module_name: Optional[str] = None, ) -> None: semanal_api = get_semanal_api(ctx) if method_node.type is None: @@ -374,23 +392,20 @@ def copy_method_to_another_class( semanal_api.defer() return - arguments = [] - bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True) - - assert bound_return_type is not None - - if isinstance(bound_return_type, PlaceholderNode): + if return_type is None: + return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name) + if return_type is None: return - try: original_arguments = method_node.arguments[1:] except AttributeError: original_arguments = [] + arguments = [] for arg_name, arg_type, original_argument in zip( method_type.arg_names[1:], method_type.arg_types[1:], original_arguments ): - bound_arg_type = semanal_api.anal_type(arg_type) + bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name) if bound_arg_type is None: return @@ -406,4 +421,10 @@ def copy_method_to_another_class( argument.set_line(original_argument) arguments.append(argument) - add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type) + add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type) + + +def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None: + sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) + if sym is not None and isinstance(sym.node, TypeInfo): + get_django_metadata(sym.node)["manager_bases"][fullname] = 1 diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 6baa7cf..bf608f6 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -53,10 +53,8 @@ def transform_form_class(ctx: ClassDefContext) -> None: forms.make_meta_nested_class_inherit_from_any(ctx) -def add_new_manager_base(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1 +def add_new_manager_base_hook(ctx: ClassDefContext) -> None: + helpers.add_new_manager_base(ctx.api, ctx.cls.fullname) def extract_django_settings_module(config_file_path: Optional[str]) -> str: @@ -235,7 +233,12 @@ class NewSemanalDjangoPlugin(Plugin): related_model_module = related_model_cls.__module__ if related_model_module != file.fullname: deps.add(self._new_dependency(related_model_module)) - return list(deps) + [self._new_dependency("django_stubs_ext")] # for annotate + return list(deps) + [ + # for QuerySet.annotate + self._new_dependency("django_stubs_ext"), + # For BaseManager.from_queryset + self._new_dependency("django.db.models.query"), + ] def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]: if fullname == "django.contrib.auth.get_user_model": @@ -305,7 +308,7 @@ class NewSemanalDjangoPlugin(Plugin): return partial(transform_model_class, django_context=self.django_context) if fullname in self._get_current_manager_bases(): - return add_new_manager_base + return add_new_manager_base_hook if fullname in self._get_current_form_bases(): return transform_form_class diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 13bb39b..681c32b 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -1,6 +1,7 @@ +from mypy.checker import fill_typevars from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo from mypy.plugin import ClassDefContext, DynamicClassDefContext -from mypy.types import AnyType, Instance, TypeOfAny +from mypy.types import CallableType, Instance, TypeVarType, UnboundType, get_proper_type from mypy_django_plugin.lib import fullnames, helpers @@ -29,15 +30,11 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte # But it should be analyzed again, so this isn't a problem. return + base_manager_instance = fill_typevars(base_manager_info) + assert isinstance(base_manager_instance, Instance) new_manager_info = semanal_api.basic_new_typeinfo( - ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line + ctx.name, basetype_or_fallback=base_manager_instance, line=ctx.call.line ) - new_manager_info.line = ctx.call.line - new_manager_info.defn.line = ctx.call.line - new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() - - current_module = semanal_api.cur_mod_node - current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) assert sym is not None @@ -52,6 +49,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte derived_queryset_info = sym.node assert isinstance(derived_queryset_info, TypeInfo) + new_manager_info.line = ctx.call.line + new_manager_info.type_vars = base_manager_info.type_vars + new_manager_info.defn.type_vars = base_manager_info.defn.type_vars + new_manager_info.defn.line = ctx.call.line + new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() + + current_module = semanal_api.cur_mod_node + current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) + if len(ctx.call.args) > 1: expr = ctx.call.args[1] assert isinstance(expr, StrExpr) @@ -64,11 +70,19 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte base_manager_info.metadata["from_queryset_managers"] = {} base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname + # So that the plugin will reparameterize the manager when it is constructed inside of a Model definition + helpers.add_new_manager_base(semanal_api, new_manager_info.fullname) + class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api) - self_type = Instance(new_manager_info, []) + self_type = fill_typevars(new_manager_info) + assert isinstance(self_type, Instance) + queryset_method_names = [] + # we need to copy all methods in MRO before django.db.models.query.QuerySet for class_mro_info in derived_queryset_info.mro: if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: + for name, sym in class_mro_info.names.items(): + queryset_method_names.append(name) break for name, sym in class_mro_info.names.items(): if isinstance(sym.node, FuncDef): @@ -80,3 +94,59 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte helpers.copy_method_to_another_class( class_def_context, self_type, new_method_name=name, method_node=func_node ) + + # Gather names of all BaseManager methods + manager_method_names = [] + for manager_mro_info in new_manager_info.mro: + if manager_mro_info.fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME: + for name, sym in manager_mro_info.names.items(): + manager_method_names.append(name) + + # Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is + # the QuerySet's self-type. Alter the return type to be the custom queryset, parameterized by the manager's model + # type variable. + for class_mro_info in derived_queryset_info.mro: + if class_mro_info.fullname != fullnames.QUERYSET_CLASS_FULLNAME: + continue + for name, sym in class_mro_info.names.items(): + if name not in manager_method_names: + continue + + if isinstance(sym.node, FuncDef): + func_node = sym.node + elif isinstance(sym.node, Decorator): + func_node = sym.node.func + else: + continue + + method_type = func_node.type + if not isinstance(method_type, CallableType): + if not semanal_api.final_iteration: + semanal_api.defer() + return None + original_return_type = method_type.ret_type + if original_return_type is None: + continue + + # Skip any method that doesn't return _QS + original_return_type = get_proper_type(original_return_type) + if isinstance(original_return_type, UnboundType): + if original_return_type.name != "_QS": + continue + elif isinstance(original_return_type, TypeVarType): + if original_return_type.name != "_QS": + continue + else: + continue + + # Return the custom queryset parameterized by the manager's type vars + return_type = Instance(derived_queryset_info, self_type.args) + + helpers.copy_method_to_another_class( + class_def_context, + self_type, + new_method_name=name, + method_node=func_node, + return_type=return_type, + original_module_name=class_mro_info.module_name, + ) diff --git a/tests/typecheck/managers/querysets/test_from_queryset.yml b/tests/typecheck/managers/querysets/test_from_queryset.yml index 53dceac..5dd4551 100644 --- a/tests/typecheck/managers/querysets/test_from_queryset.yml +++ b/tests/typecheck/managers/querysets/test_from_queryset.yml @@ -1,9 +1,11 @@ - case: from_queryset_with_base_manager main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" + reveal_type(MyModel.objects.filter(id=1).queryset_method()) # N: Revealed type is "builtins.str" + reveal_type(MyModel.objects.filter(id=1)) # N: Revealed type is "myapp.models.ModelQuerySet[myapp.models.MyModel*]" installed_apps: - myapp files: @@ -23,7 +25,7 @@ - case: from_queryset_with_manager main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" installed_apps: @@ -97,7 +99,7 @@ - case: from_queryset_with_class_inheritance main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" installed_apps: @@ -121,7 +123,7 @@ - case: from_queryset_with_manager_in_another_directory_and_imports main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.queryset_method) # N: Revealed type is "def (param: Union[builtins.str, None] =) -> Union[builtins.str, None]" reveal_type(MyModel().objects.queryset_method('str')) # N: Revealed type is "Union[builtins.str, None]" @@ -151,7 +153,7 @@ disable_cache: true main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.managers.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is "def (param: Union[builtins.int, builtins.str]) -> " reveal_type(MyModel().objects.base_queryset_method(2)) # N: Revealed type is "" @@ -183,7 +185,7 @@ - case: from_queryset_with_decorated_queryset_methods main: | from myapp.models import MyModel - reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" + reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" installed_apps: - myapp