remove catch-all __getattr__ for Manager, fix some issues with manager methods (#227)

This commit is contained in:
Maksim Kurnikov
2019-11-12 20:36:07 +03:00
committed by GitHub
parent e9a90ebff0
commit 8d986a0f43
9 changed files with 72 additions and 28 deletions

View File

@@ -10,20 +10,24 @@ from django.db import models
def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ... def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ...
class PermissionManager(models.Manager): class PermissionManager(models.Manager["Permission"]):
def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ... def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ...
class Permission(models.Model): class Permission(models.Model):
content_type_id: int content_type_id: int
objects: PermissionManager
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
codename = models.CharField(max_length=100) codename = models.CharField(max_length=100)
def natural_key(self) -> Tuple[str, str, str]: ... def natural_key(self) -> Tuple[str, str, str]: ...
class GroupManager(models.Manager): class GroupManager(models.Manager["Group"]):
def get_by_natural_key(self, name: str) -> Group: ... def get_by_natural_key(self, name: str) -> Group: ...
class Group(models.Model): class Group(models.Model):
objects: GroupManager
name = models.CharField(max_length=150) name = models.CharField(max_length=150)
permissions = models.ManyToManyField(Permission) permissions = models.ManyToManyField(Permission)
def natural_key(self): ... def natural_key(self): ...
@@ -40,8 +44,8 @@ class UserManager(BaseUserManager[_T]):
class PermissionsMixin(models.Model): class PermissionsMixin(models.Model):
is_superuser = models.BooleanField() is_superuser = models.BooleanField()
groups: models.ManyToManyField = models.ManyToManyField(Group) groups = models.ManyToManyField(Group)
user_permissions: models.ManyToManyField = models.ManyToManyField(Permission) user_permissions = models.ManyToManyField(Permission)
def get_group_permissions(self, obj: None = ...) -> Set[str]: ... def get_group_permissions(self, obj: None = ...) -> Set[str]: ...
def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ... def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ...
def has_perm(self, perm: str, obj: Optional[str] = ...) -> bool: ... def has_perm(self, perm: str, obj: Optional[str] = ...) -> bool: ...

View File

@@ -6,15 +6,16 @@ from django.db import models
SITE_CACHE: Any SITE_CACHE: Any
class SiteManager(models.Manager): class SiteManager(models.Manager["Site"]):
def get_current(self, request: Optional[HttpRequest] = ...) -> Site: ... def get_current(self, request: Optional[HttpRequest] = ...) -> Site: ...
def clear_cache(self) -> None: ... def clear_cache(self) -> None: ...
def get_by_natural_key(self, domain: str) -> Site: ... def get_by_natural_key(self, domain: str) -> Site: ...
class Site(models.Model): class Site(models.Model):
domain: models.CharField = ... objects: SiteManager
name: models.CharField = ...
objects: SiteManager = ... domain = models.CharField(max_length=100)
name = models.CharField(max_length=50)
def natural_key(self) -> Tuple[str]: ... def natural_key(self) -> Tuple[str]: ...
def clear_site_cache(sender: Type[Site], **kwargs: Any) -> None: ... def clear_site_cache(sender: Type[Site], **kwargs: Any) -> None: ...

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Iterable, Union
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
_T = TypeVar("_T", bound=Model, covariant=True) _T = TypeVar("_T", bound=Model, covariant=True)
_M = TypeVar("_M", bound="BaseManager")
class BaseManager(QuerySet[_T]): class BaseManager(QuerySet[_T]):
creation_counter: int = ... creation_counter: int = ...
@@ -20,13 +21,18 @@ class BaseManager(QuerySet[_T]):
@classmethod @classmethod
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ... def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ... def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ... def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ...
def get_queryset(self) -> QuerySet[_T]: ... def get_queryset(self) -> QuerySet[_T]: ...
class Manager(BaseManager[_T]): ... class Manager(BaseManager[_T]): ...
class RelatedManager(Manager[_T]): class RelatedManager(Manager[_T]):
def add(self, *objs: Model, bulk: bool = ...) -> None: ... related_val: Tuple[int, ...]
def add(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ...
def remove(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ...
def set(
self, objs: Union[QuerySet[_T], Iterable[Union[_T, int]]], *, bulk: bool = ..., clear: bool = ...
) -> None: ...
def clear(self) -> None: ... def clear(self) -> None: ...
class ManagerDescriptor: class ManagerDescriptor:

View File

@@ -52,8 +52,9 @@ class _BaseQuerySet(Generic[_T], Sized):
def get(self, *args: Any, **kwargs: Any) -> _T: ... def get(self, *args: Any, **kwargs: Any) -> _T: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ... def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create( def bulk_create(
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ... self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ... ) -> List[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ... def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create( def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
@@ -119,8 +120,6 @@ class _BaseQuerySet(Generic[_T], Sized):
@property @property
def db(self) -> str: ... def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
# TODO: remove when django adds __class_getitem__ methods
def __getattr__(self, item: str) -> Any: ...
class QuerySet(_BaseQuerySet[_T], Collection[_T], Sized): class QuerySet(_BaseQuerySet[_T], Collection[_T], Sized):
def __iter__(self) -> Iterator[_T]: ... def __iter__(self) -> Iterator[_T]: ...

View File

@@ -187,18 +187,12 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
calculate_mro(new_typeinfo) calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type() new_typeinfo.calculate_metaclass_type()
def add_field_to_new_typeinfo(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
var.info = new_typeinfo
var.is_initialized_in_class = is_initialized_in_class
var.is_property = is_property
var._fullname = new_typeinfo.fullname() + '.' + var.name()
new_typeinfo.names[var.name()] = SymbolTableNode(MDEF, var)
# add fields # add fields
var_items = [Var(item, typ) for item, typ in fields.items()] for field_name, field_type in fields.items():
for var_item in var_items: var = Var(field_name, type=field_type)
add_field_to_new_typeinfo(var_item, is_property=True) var.info = new_typeinfo
var._fullname = new_typeinfo.fullname() + '.' + field_name
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
classdef.info = new_typeinfo classdef.info = new_typeinfo
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)

View File

@@ -141,6 +141,7 @@ class AddManagers(ModelClassInitializer):
has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases)
if has_manager_any_base: if has_manager_any_base:
custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__ custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__
bases = [] bases = []
for original_base in manager_info.bases: for original_base in manager_info.bases:
if self._is_manager_any(original_base): if self._is_manager_any(original_base):
@@ -150,11 +151,22 @@ class AddManagers(ModelClassInitializer):
original_base = helpers.reparametrize_instance(original_base, original_base = helpers.reparametrize_instance(original_base,
[Instance(self.model_classdef.info, [])]) [Instance(self.model_classdef.info, [])])
bases.append(original_base) bases.append(original_base)
current_module = self.api.modules[self.model_classdef.info.module_name] current_module = self.api.modules[self.model_classdef.info.module_name]
custom_manager_info = helpers.add_new_class_for_module(current_module, custom_manager_info = helpers.add_new_class_for_module(current_module,
custom_model_manager_name, custom_model_manager_name,
bases=bases, bases=bases,
fields=OrderedDict()) fields=OrderedDict())
# copy fields to a new manager
for name, sym in manager_info.names.items():
new_sym = sym.copy()
if isinstance(new_sym.node, Var):
new_var = Var(name, type=sym.type)
new_var.info = custom_manager_info
new_var._fullname = custom_manager_info.fullname() + '.' + name
new_sym.node = new_var
custom_manager_info.names[name] = new_sym
custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])]) custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, custom_manager_type) self.add_new_node_to_model_class(manager_name, custom_manager_type)

View File

@@ -94,7 +94,8 @@ IGNORED_ERRORS = {
'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"', 'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"',
'Unexpected attribute "foo" for model "Article"', 'Unexpected attribute "foo" for model "Article"',
'has no attribute "touched"', 'has no attribute "touched"',
'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"' 'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"',
'"Manager[Article]" has no attribute "do_something"',
], ],
'backends': [ 'backends': [
'"DatabaseError" has no attribute "pgcode"' '"DatabaseError" has no attribute "pgcode"'
@@ -102,6 +103,10 @@ IGNORED_ERRORS = {
'builtin_server': [ 'builtin_server': [
'"ServerHandler" has no attribute', '"ServerHandler" has no attribute',
], ],
'bulk_create': [
'has incompatible type "List[Country]"; expected "Iterable[TwoFields]"',
'List item 1 has incompatible type "Country"; expected "ProxyCountry"',
],
'check_framework': [ 'check_framework': [
'base class "Model" defined the type as "Callable', 'base class "Model" defined the type as "Callable',
'Value of type "Collection[str]" is not indexable', 'Value of type "Collection[str]" is not indexable',
@@ -112,6 +117,8 @@ IGNORED_ERRORS = {
], ],
'contenttypes_tests': [ 'contenttypes_tests': [
'"FooWithBrokenAbsoluteUrl" has no attribute "unknown_field"', '"FooWithBrokenAbsoluteUrl" has no attribute "unknown_field"',
'contenttypes_tests.models.Site',
'Argument 1 to "set" of "RelatedManager" has incompatible type "SiteManager[Site]"',
], ],
'custom_lookups': [ 'custom_lookups': [
'in base class "SQLFuncMixin"', 'in base class "SQLFuncMixin"',
@@ -222,6 +229,7 @@ IGNORED_ERRORS = {
], ],
'm2m_regress': [ 'm2m_regress': [
"Cannot resolve keyword 'porcupine' into field", "Cannot resolve keyword 'porcupine' into field",
'Argument 1 to "set" of "RelatedManager" has incompatible type "int"',
], ],
'messages_tests': [ 'messages_tests': [
'List item 0 has incompatible type "Dict[str, Message]"; expected "Message"', 'List item 0 has incompatible type "Dict[str, Message]"; expected "Message"',
@@ -233,12 +241,14 @@ IGNORED_ERRORS = {
], ],
'many_to_many': [ 'many_to_many': [
'(expression has type "List[Article]", variable has type "RelatedManager[Article]"', '(expression has type "List[Article]", variable has type "RelatedManager[Article]"',
'"add" of "RelatedManager" has incompatible type "Article"; expected "Union[Publication, int]"',
], ],
'many_to_one': [ 'many_to_one': [
'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")', 'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")',
'Incompatible type for "parent" of "Child" (got "Child", expected "Union[Parent, Combinable]")', 'Incompatible type for "parent" of "Child" (got "Child", expected "Union[Parent, Combinable]")',
'expression has type "List[<nothing>]", variable has type "RelatedManager[Article]"', 'expression has type "List[<nothing>]", variable has type "RelatedManager[Article]"',
'"Reporter" has no attribute "cached_query"', '"Reporter" has no attribute "cached_query"',
'to "add" of "RelatedManager" has incompatible type "Reporter"; expected "Union[Article, int]"',
], ],
'middleware_exceptions': [ 'middleware_exceptions': [
'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any]"; expected "str"' 'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any]"; expected "str"'

View File

@@ -28,4 +28,20 @@
from django.db import models from django.db import models
class Blog(models.Model): class Blog(models.Model):
created_at = models.DateTimeField() created_at = models.DateTimeField()
- case: queryset_missing_method
main: |
from myapp.models import User
reveal_type(User.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
User.objects.not_existing_method() # E: "Manager[User]" has no attribute "not_existing_method"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class User(models.Model):
pass

View File

@@ -309,6 +309,7 @@
from myapp.models import User from myapp.models import User
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]' reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int'
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -317,6 +318,7 @@
content: | content: |
from django.db import models from django.db import models
class MyManager(models.Manager): class MyManager(models.Manager):
pass def get_instance(self) -> int:
pass
class User(models.Model): class User(models.Model):
objects = MyManager() objects = MyManager()