diff --git a/django-stubs/db/models/manager.pyi b/django-stubs/db/models/manager.pyi index cffede8..0dfc2a1 100644 --- a/django-stubs/db/models/manager.pyi +++ b/django-stubs/db/models/manager.pyi @@ -9,8 +9,9 @@ class BaseManager(QuerySet[_T, _T]): creation_counter: int = ... auto_created: bool = ... use_in_migrations: bool = ... - model: Optional[Any] = ... - name: Optional[Any] = ... + name: str = ... + model: Type[Model] = ... + db: str def __init__(self) -> None: ... def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ... def check(self, **kwargs: Any) -> List[Any]: ... @@ -34,8 +35,4 @@ class ManagerDescriptor: def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Manager: ... class EmptyManager(Manager): - creation_counter: int - name: None - model: Optional[Type[Model]] = ... def __init__(self, model: Type[Model]) -> None: ... - def get_queryset(self) -> QuerySet: ... diff --git a/django-stubs/db/models/options.pyi b/django-stubs/db/models/options.pyi index 8ab108b..4d2f612 100644 --- a/django-stubs/db/models/options.pyi +++ b/django-stubs/db/models/options.pyi @@ -37,8 +37,6 @@ class Options(Generic[_M]): default_manager: Manager fields: ImmutableList local_concrete_fields: ImmutableList - managers: ImmutableList - managers_map: Dict[str, Manager] related_objects: ImmutableList FORWARD_PROPERTIES: Any = ... REVERSE_PROPERTIES: Any = ... @@ -106,6 +104,10 @@ class Options(Generic[_M]): def many_to_many(self) -> List[ManyToManyField]: ... @property def fields_map(self) -> Dict[str, Union[Field, ForeignObjectRel]]: ... + @property + def managers(self) -> List[Manager]: ... + @property + def managers_map(self) -> Dict[str, Manager]: ... def get_field(self, field_name: Union[Callable, str]) -> Field: ... def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ... def get_parent_list(self) -> List[Type[Model]]: ... diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index b0277fc..422af18 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -131,17 +131,16 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo] return None -def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[Instance], - fields: 'OrderedDict[str, MypyType]') -> TypeInfo: - current_module = api.scope.stack[0] - new_class_unique_name = checker.gen_unique_name(name, current_module.names) +def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], + fields: 'OrderedDict[str, MypyType]') -> TypeInfo: + new_class_unique_name = checker.gen_unique_name(name, module.names) # make new class expression classdef = ClassDef(new_class_unique_name, Block([])) - classdef.fullname = current_module.fullname() + '.' + new_class_unique_name + classdef.fullname = module.fullname() + '.' + new_class_unique_name # make new TypeInfo - new_typeinfo = TypeInfo(SymbolTable(), classdef, current_module.fullname()) + new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname()) new_typeinfo.bases = bases calculate_mro(new_typeinfo) new_typeinfo.calculate_metaclass_type() @@ -160,14 +159,15 @@ def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[In add_field_to_new_typeinfo(var_item, is_property=True) classdef.info = new_typeinfo - current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) + module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) return new_typeinfo def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType: - namedtuple_info = add_new_class_for_current_module(api, name, - bases=[api.named_generic_type('typing.NamedTuple', [])], - fields=fields) + current_module = api.scope.stack[0] + namedtuple_info = add_new_class_for_module(current_module, name, + bases=[api.named_generic_type('typing.NamedTuple', [])], + fields=fields) return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 42b8e5c..c2018bf 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Type, cast +from typing import Type, cast, OrderedDict from django.db.models.base import Model from django.db.models.fields.related import ForeignKey @@ -35,7 +35,7 @@ class ModelClassInitializer: field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname) return field_info - def add_new_node_to_model_class(self, name: str, typ: Instance) -> None: + def create_new_var(self, name: str, typ: Instance, is_classvar=False) -> Var: # type=: type of the variable itself var = Var(name=name, type=typ) # var.info: type of the object variable is bound to @@ -43,6 +43,11 @@ class ModelClassInitializer: var._fullname = self.model_classdef.info.fullname() + '.' + name var.is_initialized_in_class = True var.is_inferred = True + var.is_classvar = is_classvar + return var + + def add_new_node_to_model_class(self, name: str, typ: Instance, is_classvar=False) -> None: + var = self.create_new_var(name, typ, is_classvar=is_classvar) self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) def run(self) -> None: @@ -101,21 +106,45 @@ class AddRelatedModelsId(ModelClassInitializer): class AddManagers(ModelClassInitializer): + def _is_manager_any(self, typ: Instance) -> bool: + return typ.type.fullname() == fullnames.MANAGER_CLASS_FULLNAME and type(typ.args[0]) == AnyType + def run_with_model_cls(self, model_cls: Type[Model]) -> None: for manager_name, manager in model_cls._meta.managers_map.items(): - if manager_name not in self.model_classdef.info.names: - manager_fullname = helpers.get_class_fullname(manager.__class__) - manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) + manager_fullname = helpers.get_class_fullname(manager.__class__) + manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) + if manager_name not in self.model_classdef.info.names: manager = Instance(manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class(manager_name, manager) + self.add_new_node_to_model_class(manager_name, manager, is_classvar=True) + else: + # create new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model + has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) + if has_manager_any_base: + custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__ + bases = [] + for original_base in manager_info.bases: + if self._is_manager_any(original_base): + if original_base.type is None: + if not self.api.final_iteration: + self.api.defer() + original_base = helpers.reparametrize_instance(original_base, + [Instance(self.model_classdef.info, [])]) + bases.append(original_base) + current_module = self.api.modules[self.model_classdef.info.module_name] + custom_manager_info = helpers.add_new_class_for_module(current_module, + custom_model_manager_name, + bases=bases, + fields=OrderedDict()) + custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])]) + self.add_new_node_to_model_class(manager_name, custom_manager_type, is_classvar=True) # add _default_manager if '_default_manager' not in self.model_classdef.info.names: default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__) default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname) default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class('_default_manager', default_manager) + self.add_new_node_to_model_class('_default_manager', default_manager, is_classvar=True) # add related managers for relation in self.django_context.get_model_relations(model_cls): diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index f7dda60..bc524a1 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,11 +1,10 @@ from collections import OrderedDict -from typing import Optional, Tuple, Type, Sequence, List, Union, cast +from typing import List, Optional, Sequence, Tuple, Type, Union, cast from django.core.exceptions import FieldError from django.db.models.base import Model -from django.db.models.fields.related import ForeignKey -from mypy.newsemanal.typeanal import TypeAnalyser -from mypy.nodes import NameExpr, Expression +from mypy.checker import TypeChecker +from mypy.nodes import Expression, NameExpr from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny @@ -41,6 +40,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): return ret + api = cast(TypeChecker, ctx.api) return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])]) diff --git a/test-data/typecheck/managers/test_managers.yml b/test-data/typecheck/managers/test_managers.yml index 1ef2634..46c754e 100644 --- a/test-data/typecheck/managers/test_managers.yml +++ b/test-data/typecheck/managers/test_managers.yml @@ -317,4 +317,21 @@ # class MyManager(models.Manager): # pass # class MyUser(MyBaseUser): -# objects = MyManager() \ No newline at end of file +# objects = MyManager() + +- case: custom_manager_returns_proper_model_types + main: | + from myapp.models import User + reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' + reveal_type(User.objects.select_related()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.User*, myapp.models.User*]' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyManager(models.Manager): + pass + class User(models.Model): + objects = MyManager()