create dummy classes for custom_manager_of_model usecase

This commit is contained in:
Maxim Kurnikov
2019-07-21 00:41:21 +03:00
parent 6962b42cba
commit d7d379e1cd
6 changed files with 75 additions and 30 deletions

View File

@@ -9,8 +9,9 @@ class BaseManager(QuerySet[_T, _T]):
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
model: Optional[Any] = ...
name: Optional[Any] = ...
name: str = ...
model: Type[Model] = ...
db: str
def __init__(self) -> None: ...
def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ...
def check(self, **kwargs: Any) -> List[Any]: ...
@@ -34,8 +35,4 @@ class ManagerDescriptor:
def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Manager: ...
class EmptyManager(Manager):
creation_counter: int
name: None
model: Optional[Type[Model]] = ...
def __init__(self, model: Type[Model]) -> None: ...
def get_queryset(self) -> QuerySet: ...

View File

@@ -37,8 +37,6 @@ class Options(Generic[_M]):
default_manager: Manager
fields: ImmutableList
local_concrete_fields: ImmutableList
managers: ImmutableList
managers_map: Dict[str, Manager]
related_objects: ImmutableList
FORWARD_PROPERTIES: Any = ...
REVERSE_PROPERTIES: Any = ...
@@ -106,6 +104,10 @@ class Options(Generic[_M]):
def many_to_many(self) -> List[ManyToManyField]: ...
@property
def fields_map(self) -> Dict[str, Union[Field, ForeignObjectRel]]: ...
@property
def managers(self) -> List[Manager]: ...
@property
def managers_map(self) -> Dict[str, Manager]: ...
def get_field(self, field_name: Union[Callable, str]) -> Field: ...
def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ...
def get_parent_list(self) -> List[Type[Model]]: ...

View File

@@ -131,17 +131,16 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
return None
def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[Instance],
fields: 'OrderedDict[str, MypyType]') -> TypeInfo:
current_module = api.scope.stack[0]
new_class_unique_name = checker.gen_unique_name(name, current_module.names)
def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
fields: 'OrderedDict[str, MypyType]') -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, module.names)
# make new class expression
classdef = ClassDef(new_class_unique_name, Block([]))
classdef.fullname = current_module.fullname() + '.' + new_class_unique_name
classdef.fullname = module.fullname() + '.' + new_class_unique_name
# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, current_module.fullname())
new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname())
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
@@ -160,14 +159,15 @@ def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[In
add_field_to_new_typeinfo(var_item, is_property=True)
classdef.info = new_typeinfo
current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
return new_typeinfo
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
namedtuple_info = add_new_class_for_current_module(api, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
current_module = api.scope.stack[0]
namedtuple_info = add_new_class_for_module(current_module, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))

View File

@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Type, cast
from typing import Type, cast, OrderedDict
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
@@ -35,7 +35,7 @@ class ModelClassInitializer:
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
return field_info
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
def create_new_var(self, name: str, typ: Instance, is_classvar=False) -> Var:
# type=: type of the variable itself
var = Var(name=name, type=typ)
# var.info: type of the object variable is bound to
@@ -43,6 +43,11 @@ class ModelClassInitializer:
var._fullname = self.model_classdef.info.fullname() + '.' + name
var.is_initialized_in_class = True
var.is_inferred = True
var.is_classvar = is_classvar
return var
def add_new_node_to_model_class(self, name: str, typ: Instance, is_classvar=False) -> None:
var = self.create_new_var(name, typ, is_classvar=is_classvar)
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
def run(self) -> None:
@@ -101,21 +106,45 @@ class AddRelatedModelsId(ModelClassInitializer):
class AddManagers(ModelClassInitializer):
def _is_manager_any(self, typ: Instance) -> bool:
return typ.type.fullname() == fullnames.MANAGER_CLASS_FULLNAME and type(typ.args[0]) == AnyType
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for manager_name, manager in model_cls._meta.managers_map.items():
if manager_name not in self.model_classdef.info.names:
manager_fullname = helpers.get_class_fullname(manager.__class__)
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
manager_fullname = helpers.get_class_fullname(manager.__class__)
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
if manager_name not in self.model_classdef.info.names:
manager = Instance(manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, manager)
self.add_new_node_to_model_class(manager_name, manager, is_classvar=True)
else:
# create new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model
has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases)
if has_manager_any_base:
custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__
bases = []
for original_base in manager_info.bases:
if self._is_manager_any(original_base):
if original_base.type is None:
if not self.api.final_iteration:
self.api.defer()
original_base = helpers.reparametrize_instance(original_base,
[Instance(self.model_classdef.info, [])])
bases.append(original_base)
current_module = self.api.modules[self.model_classdef.info.module_name]
custom_manager_info = helpers.add_new_class_for_module(current_module,
custom_model_manager_name,
bases=bases,
fields=OrderedDict())
custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, custom_manager_type, is_classvar=True)
# add _default_manager
if '_default_manager' not in self.model_classdef.info.names:
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname)
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class('_default_manager', default_manager)
self.add_new_node_to_model_class('_default_manager', default_manager, is_classvar=True)
# add related managers
for relation in self.django_context.get_model_relations(model_cls):

View File

@@ -1,11 +1,10 @@
from collections import OrderedDict
from typing import Optional, Tuple, Type, Sequence, List, Union, cast
from typing import List, Optional, Sequence, Tuple, Type, Union, cast
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from mypy.newsemanal.typeanal import TypeAnalyser
from mypy.nodes import NameExpr, Expression
from mypy.checker import TypeChecker
from mypy.nodes import Expression, NameExpr
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
@@ -41,6 +40,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return ret
api = cast(TypeChecker, ctx.api)
return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])])

View File

@@ -317,4 +317,21 @@
# class MyManager(models.Manager):
# pass
# class MyUser(MyBaseUser):
# objects = MyManager()
# objects = MyManager()
- case: custom_manager_returns_proper_model_types
main: |
from myapp.models import User
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.select_related()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.User*, myapp.models.User*]'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyManager(models.Manager):
pass
class User(models.Model):
objects = MyManager()