From 220f0e4cf09df170427bd83ee63f7cbe84f223f5 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Sat, 22 Jan 2022 19:13:05 +0100 Subject: [PATCH] Reuse reverse managers instead of recreating (#825) --- mypy_django_plugin/lib/helpers.py | 18 +++- mypy_django_plugin/transformers/models.py | 108 ++++++++++++++-------- tests/typecheck/fields/test_related.yml | 96 +++++++++++++++++-- 3 files changed, 170 insertions(+), 52 deletions(-) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 63dc0c2..d4ddfb1 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -209,7 +209,11 @@ def is_annotated_model_fullname(model_cls_fullname: str) -> bool: def add_new_class_for_module( - module: MypyFile, name: str, bases: List[Instance], fields: Optional[Dict[str, MypyType]] = None + module: MypyFile, + name: str, + bases: List[Instance], + fields: Optional[Dict[str, MypyType]] = None, + no_serialize: bool = False, ) -> TypeInfo: new_class_unique_name = checker.gen_unique_name(name, module.names) @@ -229,10 +233,14 @@ def add_new_class_for_module( var = Var(field_name, type=field_type) var.info = new_typeinfo var._fullname = new_typeinfo.fullname + "." + field_name - new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) + new_typeinfo.names[field_name] = SymbolTableNode( + MDEF, var, plugin_generated=True, no_serialize=no_serialize + ) classdef.info = new_typeinfo - module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) + module.names[new_class_unique_name] = SymbolTableNode( + GDEF, new_typeinfo, plugin_generated=True, no_serialize=no_serialize + ) return new_typeinfo @@ -331,7 +339,7 @@ def check_types_compatible( api.check_subtype(actual_type, expected_type, ctx.context, error_message, "got", "expected") -def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: +def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType, no_serialize: bool = False) -> None: # type=: type of the variable itself var = Var(name=name, type=sym_type) # var.info: type of the object variable is bound to @@ -339,7 +347,7 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> No var._fullname = info.fullname + "." + name var.is_initialized_in_class = True var.is_inferred = True - info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) + info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize) def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]: diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index b689f77..730a4d3 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -54,8 +54,8 @@ class ModelClassInitializer: var.is_inferred = True return var - def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: - helpers.add_new_sym_for_info(self.model_classdef.info, name=name, sym_type=typ) + def add_new_node_to_model_class(self, name: str, typ: MypyType, no_serialize: bool = False) -> None: + helpers.add_new_sym_for_info(self.model_classdef.info, name=name, sym_type=typ, no_serialize=no_serialize) def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo: current_module = self.api.modules[self.model_classdef.info.module_name] @@ -229,21 +229,22 @@ class AddManagers(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: manager_info: Optional[TypeInfo] + encountered_incomplete_manager_def = False for manager_name, manager in model_cls._meta.managers_map.items(): manager_class_name = manager.__class__.__name__ manager_fullname = helpers.get_class_fullname(manager.__class__) try: manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - # On final round, see if we can find info for a generated (dynamic class) manager - base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) - manager_info = self.get_generated_manager_info(manager_fullname, base_manager_fullname) - if manager_info is None: - continue - _, manager_class_name = manager_info.fullname.rsplit(".", maxsplit=1) + # Check if manager is a generated (dynamic class) manager + base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) + manager_info = self.get_generated_manager_info(manager_fullname, base_manager_fullname) + if manager_info is None: + # Manager doesn't appear to be generated. Track that we encountered an + # incomplete definition and skip + encountered_incomplete_manager_def = True + continue + _, manager_class_name = manager_info.fullname.rsplit(".", maxsplit=1) if manager_name not in self.model_classdef.info.names: manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) @@ -269,6 +270,10 @@ class AddManagers(ModelClassInitializer): self.add_new_node_to_model_class(manager_name, custom_manager_type) + if encountered_incomplete_manager_def and not self.api.final_iteration: + # Unless we're on the final round, see if another round could figuring out all manager types + raise helpers.IncompleteDefnException() + class AddDefaultManagerAttribute(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: @@ -297,12 +302,26 @@ class AddDefaultManagerAttribute(ModelClassInitializer): class AddRelatedManagers(ModelClassInitializer): + def get_reverse_manager_info(self, model_info: TypeInfo, derived_from: str) -> Optional[TypeInfo]: + manager_fullname = helpers.get_django_metadata(model_info).get("reverse_managers", {}).get(derived_from) + if not manager_fullname: + return None + + symbol = self.api.lookup_fully_qualified_or_none(manager_fullname) + if symbol is None or not isinstance(symbol.node, TypeInfo): + return None + return symbol.node + + def set_reverse_manager_info(self, model_info: TypeInfo, derived_from: str, fullname: str) -> None: + helpers.get_django_metadata(model_info).setdefault("reverse_managers", {})[derived_from] = fullname + def run_with_model_cls(self, model_cls: Type[Model]) -> None: # add related managers for relation in self.django_context.get_model_relations(model_cls): attname = relation.get_accessor_name() - if attname is None: - # no reverse accessor + if attname is None or attname in self.model_classdef.info.names: + # No reverse accessor or already declared. Note that this would also leave any + # explicitly declared(i.e. non-inferred) reverse accessors alone continue related_model_cls = self.django_context.get_field_related_model_cls(relation) @@ -326,7 +345,7 @@ class AddRelatedManagers(ModelClassInitializer): related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error( fullnames.RELATED_MANAGER_CLASS ) # noqa: E501 - default_manager = related_model_info.get("_default_manager") + default_manager = related_model_info.names.get("_default_manager") if not default_manager: raise helpers.IncompleteDefnException() except helpers.IncompleteDefnException as exc: @@ -335,39 +354,46 @@ class AddRelatedManagers(ModelClassInitializer): else: continue - # create new RelatedManager subclass + # Check if the related model has a related manager subclassed from the default manager + # TODO: Support other reverse managers than `_default_manager` + default_reverse_manager_info = self.get_reverse_manager_info( + model_info=related_model_info, derived_from="_default_manager" + ) + if default_reverse_manager_info: + self.add_new_node_to_model_class( + attname, Instance(default_reverse_manager_info, []), no_serialize=True + ) + return + + # The reverse manager we're looking for doesn't exist. So we create it. + # The (default) reverse manager type is built from a RelatedManager and the default manager on the related model parametrized_related_manager_type = Instance(related_manager_info, [Instance(related_model_info, [])]) default_manager_type = default_manager.type - if default_manager_type is None: - default_manager_type = self.try_generate_related_manager(related_model_cls, related_model_info) - if ( - default_manager_type is None - or not isinstance(default_manager_type, Instance) - or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME - ): + assert default_manager_type is not None + assert isinstance(default_manager_type, Instance) + # When the default manager isn't custom there's no need to create a new type + # as `RelatedManager` has `models.Manager` as base + if default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME: self.add_new_node_to_model_class(attname, parametrized_related_manager_type) continue - name = model_cls.__name__ + "_" + related_model_cls.__name__ + "_" + "RelatedManager" - bases = [parametrized_related_manager_type, default_manager_type] - new_related_manager_info = self.add_new_class_for_current_module(name, bases) + # The reverse manager is based on the related model's manager, so it makes most sense to add the new + # related manager in that module + new_related_manager_info = helpers.add_new_class_for_module( + module=self.api.modules[related_model_info.module_name], + name=f"{related_model_cls.__name__}_RelatedManager", + bases=[parametrized_related_manager_type, default_manager_type], + no_serialize=True, + ) new_related_manager_info.metadata["django"] = {"related_manager_to_model": related_model_info.fullname} - - self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, [])) - - def try_generate_related_manager( - self, related_model_cls: Type[Model], related_model_info: TypeInfo - ) -> Optional[Instance]: - manager = related_model_cls._meta.managers_map["_default_manager"] - base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) - manager_fullname = helpers.get_class_fullname(manager.__class__) - generated_managers = self.get_generated_manager_mappings(base_manager_fullname) - if manager_fullname in generated_managers: - real_manager_fullname = generated_managers[manager_fullname] - manager_info = self.lookup_typeinfo(real_manager_fullname) - if manager_info: - return Instance(manager_info, [Instance(related_model_info, [])]) - return None + # Stash the new reverse manager type fullname on the related model, so we don't duplicate + # or have to create it again for other reverse relations + self.set_reverse_manager_info( + related_model_info, + derived_from="_default_manager", + fullname=new_related_manager_info.fullname, + ) + self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, []), no_serialize=True) class AddExtraFieldMethods(ModelClassInitializer): diff --git a/tests/typecheck/fields/test_related.yml b/tests/typecheck/fields/test_related.yml index 1120678..1c635a2 100644 --- a/tests/typecheck/fields/test_related.yml +++ b/tests/typecheck/fields/test_related.yml @@ -608,8 +608,8 @@ reveal_type(Article().registered_by_user) # N: Revealed type is "myapp.models.MyUser*" user = MyUser() - reveal_type(user.book_set) # N: Revealed type is "myapp.models.MyUser_Book_RelatedManager1" - reveal_type(user.article_set) # N: Revealed type is "myapp.models.MyUser_Article_RelatedManager1" + reveal_type(user.book_set) # N: Revealed type is "myapp.models.Book_RelatedManager" + reveal_type(user.article_set) # N: Revealed type is "myapp.models.Article_RelatedManager" reveal_type(user.book_set.add) # N: Revealed type is "def (*objs: Union[myapp.models.Book*, builtins.int], *, bulk: builtins.bool =)" reveal_type(user.article_set.add) # N: Revealed type is "def (*objs: Union[myapp.models.Article*, builtins.int], *, bulk: builtins.bool =)" reveal_type(user.book_set.filter) # N: Revealed type is "def (*args: Any, **kwargs: Any) -> myapp.models.LibraryEntityQuerySet[myapp.models.Book*]" @@ -689,18 +689,18 @@ - case: related_manager_is_a_subclass_of_default_manager main: | from myapp.models import User, Order, Product - reveal_type(User().orders) # N: Revealed type is "myapp.models.User_Order_RelatedManager1" + reveal_type(User().orders) # N: Revealed type is "myapp.models.Order_RelatedManager" reveal_type(User().orders.get()) # N: Revealed type is "myapp.models.Order*" reveal_type(User().orders.manager_method()) # N: Revealed type is "builtins.int" reveal_type(Product.objects.queryset_method()) # N: Revealed type is "builtins.int" - reveal_type(Order().products) # N: Revealed type is "myapp.models.Order_Product_RelatedManager1" + reveal_type(Order().products) # N: Revealed type is "myapp.models.Product_RelatedManager" reveal_type(Order().products.get()) # N: Revealed type is "myapp.models.Product*" reveal_type(Order().products.queryset_method()) # N: Revealed type is "builtins.int" - # TODO: realted manager support to use the same type for all related managers if 1 == 2: manager = User().products else: - manager = Order().products # E: Incompatible types in assignment (expression has type "Order_Product_RelatedManager1", variable has type "User_Product_RelatedManager1") + manager = Order().products + reveal_type(manager) # N: Revealed type is "myapp.models.Product_RelatedManager" installed_apps: - myapp files: @@ -725,6 +725,90 @@ order = models.ForeignKey(to=Order, on_delete=models.CASCADE, related_name='products') user = models.ForeignKey(to=User, on_delete=models.CASCADE, related_name='products') +- case: related_manager_shared_between_multiple_relations + main: | + from myapp.models.store import Store + from myapp.models.user import User + reveal_type(Store().purchases) # N: Revealed type is "myapp.models.purchase.Purchase_RelatedManager" + reveal_type(Store().purchases.queryset_method()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet" + reveal_type(Store().purchases.filter()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet[myapp.models.purchase.Purchase*]" + reveal_type(Store().purchases.filter().queryset_method()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet" + reveal_type(User().purchases) # N: Revealed type is "myapp.models.purchase.Purchase_RelatedManager" + reveal_type(User().purchases.queryset_method()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet" + reveal_type(User().purchases.filter()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet[myapp.models.purchase.Purchase*]" + reveal_type(User().purchases.filter().queryset_method()) # N: Revealed type is "myapp.models.querysets.PurchaseQuerySet" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models/__init__.py + content: | + from .purchase import Purchase + from .store import Store + from .user import User + - path: myapp/models/store.py + content: | + from django.db import models + + class Store(models.Model): + ... + - path: myapp/models/user.py + content: | + from django.db import models + + class User(models.Model): + ... + - path: myapp/models/querysets.py + content: | + from django.db.models import QuerySet + from typing import TYPE_CHECKING + if TYPE_CHECKING: + from .purchase import Purchase + + class PurchaseQuerySet(QuerySet['Purchase']): + def queryset_method(self) -> "PurchaseQuerySet": + return self.all() + - path: myapp/models/purchase.py + content: | + from django.db import models + from django.db.models.manager import BaseManager + from .querysets import PurchaseQuerySet + from .store import Store + from .user import User + + PurchaseManager = BaseManager.from_queryset(PurchaseQuerySet) + class Purchase(models.Model): + objects = PurchaseManager() + store = models.ForeignKey(to=Store, on_delete=models.CASCADE, related_name='purchases') + user = models.ForeignKey(to=User, on_delete=models.CASCADE, related_name='purchases') + +- case: explicitly_declared_related_manager_is_not_overridden + main: | + from myapp.models import User + reveal_type(User().purchases) # N: Revealed type is "builtins.int" + User().purchases.filter() # E: "int" has no attribute "filter" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from django.db.models.manager import BaseManager + + class User(models.Model): + purchases: int + + class PurchaseQuerySet(models.QuerySet['Purchase']): + def queryset_method(self) -> "PurchaseQuerySet": + return self.all() + + PurchaseManager = BaseManager.from_queryset(PurchaseQuerySet) + class Purchase(models.Model): + objects = PurchaseManager() + user = models.ForeignKey(to=User, on_delete=models.CASCADE, related_name='purchases') + + - case: related_manager_no_conflict_from_star_import main: | import myapp.models