remove catch-all __getattr__ for Manager, fix some issues with manager methods (#227)

This commit is contained in:
Maksim Kurnikov
2019-11-12 20:36:07 +03:00
committed by GitHub
parent e9a90ebff0
commit 8d986a0f43
9 changed files with 72 additions and 28 deletions

View File

@@ -10,20 +10,24 @@ from django.db import models
def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ...
class PermissionManager(models.Manager):
class PermissionManager(models.Manager["Permission"]):
def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ...
class Permission(models.Model):
content_type_id: int
objects: PermissionManager
name = models.CharField(max_length=255)
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
codename = models.CharField(max_length=100)
def natural_key(self) -> Tuple[str, str, str]: ...
class GroupManager(models.Manager):
class GroupManager(models.Manager["Group"]):
def get_by_natural_key(self, name: str) -> Group: ...
class Group(models.Model):
objects: GroupManager
name = models.CharField(max_length=150)
permissions = models.ManyToManyField(Permission)
def natural_key(self): ...
@@ -40,8 +44,8 @@ class UserManager(BaseUserManager[_T]):
class PermissionsMixin(models.Model):
is_superuser = models.BooleanField()
groups: models.ManyToManyField = models.ManyToManyField(Group)
user_permissions: models.ManyToManyField = models.ManyToManyField(Permission)
groups = models.ManyToManyField(Group)
user_permissions = models.ManyToManyField(Permission)
def get_group_permissions(self, obj: None = ...) -> Set[str]: ...
def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ...
def has_perm(self, perm: str, obj: Optional[str] = ...) -> bool: ...

View File

@@ -6,15 +6,16 @@ from django.db import models
SITE_CACHE: Any
class SiteManager(models.Manager):
class SiteManager(models.Manager["Site"]):
def get_current(self, request: Optional[HttpRequest] = ...) -> Site: ...
def clear_cache(self) -> None: ...
def get_by_natural_key(self, domain: str) -> Site: ...
class Site(models.Model):
domain: models.CharField = ...
name: models.CharField = ...
objects: SiteManager = ...
objects: SiteManager
domain = models.CharField(max_length=100)
name = models.CharField(max_length=50)
def natural_key(self) -> Tuple[str]: ...
def clear_site_cache(sender: Type[Site], **kwargs: Any) -> None: ...

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Iterable, Union
from django.db.models.base import Model
from django.db.models.query import QuerySet
_T = TypeVar("_T", bound=Model, covariant=True)
_M = TypeVar("_M", bound="BaseManager")
class BaseManager(QuerySet[_T]):
creation_counter: int = ...
@@ -20,13 +21,18 @@ class BaseManager(QuerySet[_T]):
@classmethod
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ...
def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ...
def get_queryset(self) -> QuerySet[_T]: ...
class Manager(BaseManager[_T]): ...
class RelatedManager(Manager[_T]):
def add(self, *objs: Model, bulk: bool = ...) -> None: ...
related_val: Tuple[int, ...]
def add(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ...
def remove(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ...
def set(
self, objs: Union[QuerySet[_T], Iterable[Union[_T, int]]], *, bulk: bool = ..., clear: bool = ...
) -> None: ...
def clear(self) -> None: ...
class ManagerDescriptor:

View File

@@ -52,8 +52,9 @@ class _BaseQuerySet(Generic[_T], Sized):
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
@@ -119,8 +120,6 @@ class _BaseQuerySet(Generic[_T], Sized):
@property
def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
# TODO: remove when django adds __class_getitem__ methods
def __getattr__(self, item: str) -> Any: ...
class QuerySet(_BaseQuerySet[_T], Collection[_T], Sized):
def __iter__(self) -> Iterator[_T]: ...

View File

@@ -187,18 +187,12 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
def add_field_to_new_typeinfo(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
var.info = new_typeinfo
var.is_initialized_in_class = is_initialized_in_class
var.is_property = is_property
var._fullname = new_typeinfo.fullname() + '.' + var.name()
new_typeinfo.names[var.name()] = SymbolTableNode(MDEF, var)
# add fields
var_items = [Var(item, typ) for item, typ in fields.items()]
for var_item in var_items:
add_field_to_new_typeinfo(var_item, is_property=True)
for field_name, field_type in fields.items():
var = Var(field_name, type=field_type)
var.info = new_typeinfo
var._fullname = new_typeinfo.fullname() + '.' + field_name
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
classdef.info = new_typeinfo
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)

View File

@@ -141,6 +141,7 @@ class AddManagers(ModelClassInitializer):
has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases)
if has_manager_any_base:
custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__
bases = []
for original_base in manager_info.bases:
if self._is_manager_any(original_base):
@@ -150,11 +151,22 @@ class AddManagers(ModelClassInitializer):
original_base = helpers.reparametrize_instance(original_base,
[Instance(self.model_classdef.info, [])])
bases.append(original_base)
current_module = self.api.modules[self.model_classdef.info.module_name]
custom_manager_info = helpers.add_new_class_for_module(current_module,
custom_model_manager_name,
bases=bases,
fields=OrderedDict())
# copy fields to a new manager
for name, sym in manager_info.names.items():
new_sym = sym.copy()
if isinstance(new_sym.node, Var):
new_var = Var(name, type=sym.type)
new_var.info = custom_manager_info
new_var._fullname = custom_manager_info.fullname() + '.' + name
new_sym.node = new_var
custom_manager_info.names[name] = new_sym
custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, custom_manager_type)

View File

@@ -94,7 +94,8 @@ IGNORED_ERRORS = {
'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"',
'Unexpected attribute "foo" for model "Article"',
'has no attribute "touched"',
'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"'
'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"',
'"Manager[Article]" has no attribute "do_something"',
],
'backends': [
'"DatabaseError" has no attribute "pgcode"'
@@ -102,6 +103,10 @@ IGNORED_ERRORS = {
'builtin_server': [
'"ServerHandler" has no attribute',
],
'bulk_create': [
'has incompatible type "List[Country]"; expected "Iterable[TwoFields]"',
'List item 1 has incompatible type "Country"; expected "ProxyCountry"',
],
'check_framework': [
'base class "Model" defined the type as "Callable',
'Value of type "Collection[str]" is not indexable',
@@ -112,6 +117,8 @@ IGNORED_ERRORS = {
],
'contenttypes_tests': [
'"FooWithBrokenAbsoluteUrl" has no attribute "unknown_field"',
'contenttypes_tests.models.Site',
'Argument 1 to "set" of "RelatedManager" has incompatible type "SiteManager[Site]"',
],
'custom_lookups': [
'in base class "SQLFuncMixin"',
@@ -222,6 +229,7 @@ IGNORED_ERRORS = {
],
'm2m_regress': [
"Cannot resolve keyword 'porcupine' into field",
'Argument 1 to "set" of "RelatedManager" has incompatible type "int"',
],
'messages_tests': [
'List item 0 has incompatible type "Dict[str, Message]"; expected "Message"',
@@ -233,12 +241,14 @@ IGNORED_ERRORS = {
],
'many_to_many': [
'(expression has type "List[Article]", variable has type "RelatedManager[Article]"',
'"add" of "RelatedManager" has incompatible type "Article"; expected "Union[Publication, int]"',
],
'many_to_one': [
'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")',
'Incompatible type for "parent" of "Child" (got "Child", expected "Union[Parent, Combinable]")',
'expression has type "List[<nothing>]", variable has type "RelatedManager[Article]"',
'"Reporter" has no attribute "cached_query"',
'to "add" of "RelatedManager" has incompatible type "Reporter"; expected "Union[Article, int]"',
],
'middleware_exceptions': [
'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any]"; expected "str"'

View File

@@ -29,3 +29,19 @@
class Blog(models.Model):
created_at = models.DateTimeField()
- case: queryset_missing_method
main: |
from myapp.models import User
reveal_type(User.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
User.objects.not_existing_method() # E: "Manager[User]" has no attribute "not_existing_method"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class User(models.Model):
pass

View File

@@ -309,6 +309,7 @@
from myapp.models import User
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
@@ -317,6 +318,7 @@
content: |
from django.db import models
class MyManager(models.Manager):
def get_instance(self) -> int:
pass
class User(models.Model):
objects = MyManager()