add support for managers as generics

This commit is contained in:
Maxim Kurnikov
2018-12-07 22:11:22 +03:00
parent 94ddb8c864
commit c9ad40d7e3
8 changed files with 237 additions and 187 deletions

View File

@@ -1,4 +1,4 @@
from typing import Type, Union, TypeVar, Any, Generic, List, Optional, Dict, Callable, Tuple, Sequence from typing import Type, Union, TypeVar, Any, Generic, List, Optional, Dict, Callable, Tuple, Sequence, TYPE_CHECKING
from uuid import UUID from uuid import UUID
from django.db import models from django.db import models
@@ -6,6 +6,9 @@ from django.db.models import Field, Model, QuerySet
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query_utils import PathInfo, Q from django.db.models.query_utils import PathInfo, Q
if TYPE_CHECKING:
from django.db.models.manager import RelatedManager
_T = TypeVar("_T", bound=models.Model) _T = TypeVar("_T", bound=models.Model)
class RelatedField(FieldCacheMixin, Field): class RelatedField(FieldCacheMixin, Field):
@@ -64,7 +67,7 @@ class ManyToManyField(RelatedField, Generic[_T]):
**kwargs: Any **kwargs: Any
) -> None: ... ) -> None: ...
def __set__(self, instance, value: Sequence[_T]) -> None: ... def __set__(self, instance, value: Sequence[_T]) -> None: ...
def __get__(self, instance, owner) -> QuerySet[_T]: ... def __get__(self, instance, owner) -> RelatedManager[_T]: ...
def check(self, **kwargs: Any) -> List[Any]: ... def check(self, **kwargs: Any) -> List[Any]: ...
def deconstruct(self) -> Tuple[Optional[str], str, List[Any], Dict[str, str]]: ... def deconstruct(self) -> Tuple[Optional[str], str, List[Any], Dict[str, str]]: ...
def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ... def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ...

View File

@@ -1,16 +1,11 @@
from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple, Type, Union, TypeVar, Set, Generic, Iterator
from unittest.mock import MagicMock
from django.db.models import Q
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.query import QuerySet, RawQuerySet from django.db.models.query import QuerySet
_T = TypeVar("_T", bound=Model) _T = TypeVar("_T", bound=Model)
class BaseManager: class BaseManager(QuerySet[_T]):
creation_counter: int = ... creation_counter: int = ...
auto_created: bool = ... auto_created: bool = ...
use_in_migrations: bool = ... use_in_migrations: bool = ...
@@ -24,57 +19,13 @@ class BaseManager:
def from_queryset(cls, queryset_class: Any, class_name: Optional[Any] = ...): ... def from_queryset(cls, queryset_class: Any, class_name: Optional[Any] = ...): ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ... def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ... def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ...
@property def get_queryset(self) -> QuerySet[_T]: ...
def db(self) -> str: ...
def get_queryset(self) -> QuerySet: ...
def all(self) -> QuerySet: ...
def __eq__(self, other: Optional[Any]) -> bool: ...
def __hash__(self): ...
class Manager(Generic[_T]): class Manager(BaseManager[_T]): ...
def exists(self) -> bool: ...
def explain(self, *, format: Optional[Any] = ..., **options: Any) -> str: ... class RelatedManager(Manager[_T]):
def raw( def add(self, *objs: _T, bulk: bool = ...) -> None: ...
self, def clear(self) -> None: ...
raw_query: str,
params: Optional[Union[Dict[str, str], List[datetime], List[Decimal], List[str], Set[str], Tuple[int]]] = ...,
translations: Optional[Dict[str, str]] = ...,
using: None = ...,
) -> RawQuerySet: ...
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
def values_list(self, *fields: Any, flat: bool = ..., named: bool = ...) -> QuerySet: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(
self, filter_obj: Union[Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock]
) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def extra(
self,
select: Optional[Union[Dict[str, int], Dict[str, str], OrderedDict]] = ...,
where: Optional[List[str]] = ...,
params: Optional[Union[List[int], List[str]]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
) -> QuerySet[_T]: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Optional[Union[datetime, float]]]: ...
def count(self) -> int: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def create(self, **kwargs: Any) -> _T: ...
class ManagerDescriptor: class ManagerDescriptor:
manager: Manager = ... manager: Manager = ...

View File

@@ -1,23 +1,9 @@
from collections import OrderedDict from collections import OrderedDict
from datetime import date, datetime from datetime import date, datetime
from typing import ( from typing import TypeVar, Optional, Any, Type, Dict, Union, overload, List, Iterator, Tuple, Callable, Iterable, Sized
TypeVar,
Optional,
Any,
Type,
Dict,
Union,
overload,
List,
Iterator,
Tuple,
Callable,
Iterable,
Sized,
Reversible,
)
from django.db import models from django.db import models
from django.db.models import Manager
_T = TypeVar("_T", bound=models.Model) _T = TypeVar("_T", bound=models.Model)
@@ -30,7 +16,7 @@ class QuerySet(Iterable[_T], Sized):
hints: Optional[Dict[str, models.Model]] = ..., hints: Optional[Dict[str, models.Model]] = ...,
) -> None: ... ) -> None: ...
@classmethod @classmethod
def as_manager(cls): ... def as_manager(cls) -> Manager[Any]: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ... def __iter__(self) -> Iterator[_T]: ...
def __bool__(self) -> bool: ... def __bool__(self) -> bool: ...
@@ -110,6 +96,8 @@ class QuerySet(Iterable[_T], Sized):
@property @property
def db(self) -> str: ... def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
# TODO: remove when django adds __class_getitem__ methods
def __getattr__(self, item: str) -> Any: ...
class RawQuerySet: class RawQuerySet:
pass pass

View File

@@ -4,11 +4,23 @@ from typing import Dict, Optional
from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models']) models_module = '.'.join([app_name, 'models'])

View File

@@ -1,24 +1,51 @@
import os import os
from typing import Callable, Optional from typing import Callable, Optional, cast
from mypy.checker import TypeChecker
from mypy.options import Options from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext
from mypy.types import Type from mypy.types import Type, Instance
from mypy.typevars import fill_typevars
from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
base_model_classes = {helpers.MODEL_CLASS_FULLNAME} base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
manager_subclasses = set()
class TransformModelClassHook(object): def transform_model_class(ctx: ClassDefContext) -> None:
def __call__(self, ctx: ClassDefContext) -> None: base_model_classes.add(ctx.cls.fullname)
base_model_classes.add(ctx.cls.fullname) process_model_class(ctx)
process_model_class(ctx)
def add_new_manager_subclass(ctx: ClassDefContext) -> None:
manager_subclasses.add(ctx.cls.fullname)
def determine_proper_manager_type(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
ret = ctx.default_return_type
if not api.tscope.classes:
# not in class
return ret
outer_model_info = api.tscope.classes[0]
if not outer_model_info.has_base(helpers.MODEL_CLASS_FULLNAME):
return ret
if not isinstance(ret, Instance):
return ret
for i, base in enumerate(ret.type.bases):
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}:
ret.type.bases[i] = reparametrize_with(base, [Instance(outer_model_info, [])])
return ret
return ret
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
@@ -39,21 +66,27 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]: ) -> Optional[Callable[[FunctionContext], Type]]:
if fullname in {helpers.FOREIGN_KEY_FULLNAME, if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}: helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field return extract_to_parameter_as_get_ret_type_for_related_field
if fullname == 'django.contrib.postgres.fields.array.ArrayField': if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field return determine_type_of_array_field
if fullname in manager_subclasses:
return determine_proper_manager_type
return None return None
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes: if fullname in base_model_classes:
return TransformModelClassHook() return transform_model_class
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings_module=self.django_settings) return DjangoConfSettingsInitializerHook(settings_module=self.django_settings)
if fullname in helpers.MANAGER_CLASSES:
return add_new_manager_subclass
return None return None

View File

@@ -1,3 +1,5 @@
import dataclasses
from abc import abstractmethod, ABCMeta
from typing import cast, Iterator, Tuple, Optional, Dict from typing import cast, Iterator, Tuple, Optional, Dict
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \ from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
@@ -9,6 +11,47 @@ from mypy.types import Instance
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
@dataclasses.dataclass
class ModelClassInitializer(metaclass=ABCMeta):
api: SemanticAnalyzerPass2
model_classdef: ClassDef
@classmethod
def from_ctx(cls, ctx: ClassDefContext):
return cls(api=cast(SemanticAnalyzerPass2, ctx.api), model_classdef=ctx.cls)
def get_nested_meta_node(self) -> Optional[TypeInfo]:
metaclass_sym = self.model_classdef.info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def is_abstract_model(self) -> bool:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return False
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
is_abstract = self.api.parse_bool(rvalue)
if is_abstract:
# abstract model do not need 'objects' queryset
return True
return False
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
var._fullname = self.model_classdef.info.fullname() + '.' + name
var.is_inferred = True
var.is_initialized_in_class = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var)
@abstractmethod
def run(self) -> None:
raise NotImplementedError()
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None: def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ) var = Var(name=name, type=typ)
var.info = typ.type var.info = typ.type
@@ -37,78 +80,76 @@ def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]: def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in iter_call_assignments(klass): for lvalue, rvalue in iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr) if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)): and isinstance(rvalue.callee, MemberExpr)):
if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}: helpers.ONETOONE_FIELD_FULLNAME}:
yield lvalue, rvalue yield lvalue, rvalue
def get_nested_meta_class(model_type: TypeInfo) -> Optional[TypeInfo]: class SetIdAttrsForRelatedFields(ModelClassInitializer):
metaclass_sym = model_type.names.get('Meta') def run(self) -> None:
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
return metaclass_sym.node self.add_new_node_to_model_class(lvalue.name + '_id',
return None typ=self.api.named_type('__builtins__.int'))
def is_abstract_model(ctx: ClassDefContext) -> bool: class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
meta_node = get_nested_meta_class(ctx.cls.info) def run(self) -> None:
if meta_node is None: meta_node = self.get_nested_meta_node()
return False if meta_node is None:
return None
for lvalue, rvalue in iter_over_assignments(meta_node.defn): meta_node.fallback_to_any = True
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
is_abstract = ctx.api.parse_bool(rvalue)
if is_abstract:
# abstract model do not need 'objects' queryset
return True
return False
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: class AddDefaultObjectsManager(ModelClassInitializer):
api = ctx.api def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls): if 'objects' in self.model_classdef.info.names:
property_name = lvalue.name + '_id' return None
add_new_var_node_to_class(ctx.cls.info, property_name,
typ=api.named_type('__builtins__.int')) if self.is_abstract_model():
# abstract models do not need 'objects' queryset
return None
typ = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
if not typ:
return None
self.add_new_node_to_model_class('objects', typ)
def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None: class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):
api = cast(SemanticAnalyzerPass2, ctx.api) def run(self) -> None:
if is_abstract_model(ctx): if self.is_abstract_model():
return None # no need for .id attr
return None
for _, rvalue in iter_call_assignments(ctx.cls): for _, rvalue in iter_call_assignments(self.model_classdef):
if ('primary_key' in rvalue.arg_names and if ('primary_key' in rvalue.arg_names
api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])): and self.api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
break break
else: else:
add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int')) self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.int'))
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None: class AddRelatedManagers(ModelClassInitializer):
# search over mro def run(self) -> None:
objects_sym = ctx.cls.info.get('objects') for module_name, module_file in self.api.modules.items():
if objects_sym is not None: for defn in iter_over_classdefs(module_file):
return None for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
# only direct Meta class ref_to_fullname = extract_ref_to_fullname(rvalue,
if is_abstract_model(ctx): module_file=module_file,
# abstract model do not need 'objects' queryset all_modules=self.api.modules)
return None if self.model_classdef.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
api = cast(SemanticAnalyzerPass2, ctx.api) related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, if not isinstance(related_name_expr, StrExpr):
args=[Instance(ctx.cls.info, [])]) return None
if not typ: related_name = related_name_expr.value
return None typ = get_related_field_type(rvalue, self.api, defn.info)
add_new_var_node_to_class(ctx.cls.info, 'objects', typ=typ) if typ is None:
return None
self.add_new_node_to_model_class(related_name, typ)
def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
meta_node = get_nested_meta_class(ctx.cls.info)
if meta_node is None:
return None
meta_node.fallback_to_any = True
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]: def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
@@ -119,8 +160,8 @@ def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2, def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
related_model_typ: TypeInfo) -> Optional[Instance]: related_model_typ: TypeInfo) -> Optional[Instance]:
if rvalue.callee.name == 'ForeignKey': if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}:
return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME,
args=[Instance(related_model_typ, [])]) args=[Instance(related_model_typ, [])])
else: else:
return Instance(related_model_typ, []) return Instance(related_model_typ, [])
@@ -129,7 +170,9 @@ def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr): if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
module = module_file.names[expr.callee.expr.name] module = module_file.names[expr.callee.expr.name]
if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey', 'OneToOneField'}: if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey',
'OneToOneField',
'ManyToManyField'}:
return True return True
return False return False
@@ -150,31 +193,16 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
return None return None
def add_related_managers(ctx: ClassDefContext):
api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name, module_file in ctx.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file,
all_modules=api.modules)
if ctx.cls.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_name = related_name_expr.value
typ = get_related_field_type(rvalue, api, defn.info)
if typ is None:
return None
add_new_var_node_to_class(ctx.cls.info, related_name, typ)
def process_model_class(ctx: ClassDefContext) -> None: def process_model_class(ctx: ClassDefContext) -> None:
add_related_managers(ctx) initializers = [
inject_any_as_base_for_nested_class_meta(ctx) InjectAnyAsBaseForNestedMeta,
set_fieldname_attrs_for_related_fields(ctx) AddDefaultObjectsManager,
add_int_id_attribute_if_primary_key_true_is_not_present(ctx) AddIdAttributeIfPrimaryKeyTrueIsNotSet,
set_objects_queryset_to_model_class(ctx) SetIdAttrsForRelatedFields,
AddRelatedManagers
]
for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run()
# allow unspecified attributes for now
ctx.cls.info.fallback_to_any = True ctx.cls.info.fallback_to_any = True

View File

@@ -1,8 +1,21 @@
[CASE test_every_model_has_objects_queryset_available] [CASE test_every_model_has_objects_queryset_available]
from django.db import models from django.db import models
class User(models.Model): class User(models.Model):
pass pass
reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class]
from django.db import models
class UserManager(models.Manager):
def get_or_404(self) -> User:
pass
class User(models.Model):
objects = UserManager()
reveal_type(User.objects) # E: Revealed type is 'main.UserManager'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
reveal_type(User.objects.get_or_404()) # E: Revealed type is 'main.User'
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'

View File

@@ -12,7 +12,7 @@ book = Book()
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
publisher = Publisher() publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' reveal_type(publisher.books) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
[CASE test_foreign_key_field_creates_attribute_with_underscore_id] [CASE test_foreign_key_field_creates_attribute_with_underscore_id]
from django.db import models from django.db import models
@@ -46,8 +46,8 @@ reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
reveal_type(book.publisher2) # E: Revealed type is 'main.Publisher*' reveal_type(book.publisher2) # E: Revealed type is 'main.Publisher*'
publisher = Publisher() publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' reveal_type(publisher.books) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
reveal_type(publisher.books2) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' reveal_type(publisher.books2) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
[CASE test_to_parameter_as_string_with_application_name__model_imported] [CASE test_to_parameter_as_string_with_application_name__model_imported]
from django.db import models from django.db import models
@@ -88,9 +88,9 @@ from django.db import models
class App(models.Model): class App(models.Model):
def method(self) -> None: def method(self) -> None:
reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(self.views) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View]'
reveal_type(self.members) # E: Revealed type is 'django.db.models.query.QuerySet[main.Member]' reveal_type(self.members) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Member]'
reveal_type(self.sheets) # E: Revealed type is 'django.db.models.query.QuerySet[main.Sheet]' reveal_type(self.sheets) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Sheet]'
reveal_type(self.profile) # E: Revealed type is 'main.Profile' reveal_type(self.profile) # E: Revealed type is 'main.Profile'
class View(models.Model): class View(models.Model):
app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE) app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE)
@@ -107,7 +107,7 @@ from myapp.models import App
class View(models.Model): class View(models.Model):
app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE) app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE)
reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(View().app.views) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View]'
reveal_type(View().app.unknown) # E: Revealed type is 'Any' reveal_type(View().app.unknown) # E: Revealed type is 'Any'
[out] [out]
@@ -116,7 +116,7 @@ reveal_type(View().app.unknown) # E: Revealed type is 'Any'
from django.db import models from django.db import models
class App(models.Model): class App(models.Model):
def method(self) -> None: def method(self) -> None:
reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(self.views) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View]'
[CASE models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model] [CASE models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model]
from django.db.models import Model from django.db.models import Model
@@ -131,8 +131,8 @@ class View(Model):
class View2(View): class View2(View):
app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views2') app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views2')
reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(App().views) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View]'
reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]' reveal_type(App().views2) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View2]'
[out] [out]
[CASE models_imported_inside_init_file] [CASE models_imported_inside_init_file]
@@ -140,7 +140,7 @@ from django.db import models
from myapp.models import App from myapp.models import App
class View(models.Model): class View(models.Model):
app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE) app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE)
reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(View().app.views) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View]'
[file myapp/__init__.py] [file myapp/__init__.py]
[file myapp/models/__init__.py] [file myapp/models/__init__.py]
@@ -191,4 +191,26 @@ from django.db import models
class App(models.Model): class App(models.Model):
owner = models.ForeignKey(to='myapp.User', on_delete=models.CASCADE, related_name='apps') owner = models.ForeignKey(to='myapp.User', on_delete=models.CASCADE, related_name='apps')
[CASE many_to_many_field_converts_to_queryset_of_model_type]
from django.db import models
class App(models.Model):
pass
class Member(models.Model):
apps = models.ManyToManyField(to=App, related_name='members')
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.App*]'
reveal_type(App().members) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Member]'
[out]
[CASE many_to_many_works_with_string_if_imported]
from django.db import models
from myapp.models import App
class Member(models.Model):
apps = models.ManyToManyField(to='myapp.App', related_name='members')
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.App*]'
[file myapp/__init__.py]
[file myapp/models.py]
from django.db import models
class App(models.Model):
pass
[out]