diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index 9aab5ff..fcc4e20 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -10,20 +10,24 @@ from django.db import models def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ... -class PermissionManager(models.Manager): +class PermissionManager(models.Manager["Permission"]): def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ... class Permission(models.Model): content_type_id: int + objects: PermissionManager + name = models.CharField(max_length=255) content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) codename = models.CharField(max_length=100) def natural_key(self) -> Tuple[str, str, str]: ... -class GroupManager(models.Manager): +class GroupManager(models.Manager["Group"]): def get_by_natural_key(self, name: str) -> Group: ... class Group(models.Model): + objects: GroupManager + name = models.CharField(max_length=150) permissions = models.ManyToManyField(Permission) def natural_key(self): ... @@ -40,8 +44,8 @@ class UserManager(BaseUserManager[_T]): class PermissionsMixin(models.Model): is_superuser = models.BooleanField() - groups: models.ManyToManyField = models.ManyToManyField(Group) - user_permissions: models.ManyToManyField = models.ManyToManyField(Permission) + groups = models.ManyToManyField(Group) + user_permissions = models.ManyToManyField(Permission) def get_group_permissions(self, obj: None = ...) -> Set[str]: ... def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ... def has_perm(self, perm: str, obj: Optional[str] = ...) -> bool: ... diff --git a/django-stubs/contrib/sites/models.pyi b/django-stubs/contrib/sites/models.pyi index 7b74d89..385c35d 100644 --- a/django-stubs/contrib/sites/models.pyi +++ b/django-stubs/contrib/sites/models.pyi @@ -6,15 +6,16 @@ from django.db import models SITE_CACHE: Any -class SiteManager(models.Manager): +class SiteManager(models.Manager["Site"]): def get_current(self, request: Optional[HttpRequest] = ...) -> Site: ... def clear_cache(self) -> None: ... def get_by_natural_key(self, domain: str) -> Site: ... class Site(models.Model): - domain: models.CharField = ... - name: models.CharField = ... - objects: SiteManager = ... + objects: SiteManager + + domain = models.CharField(max_length=100) + name = models.CharField(max_length=50) def natural_key(self) -> Tuple[str]: ... def clear_site_cache(sender: Type[Site], **kwargs: Any) -> None: ... diff --git a/django-stubs/db/models/manager.pyi b/django-stubs/db/models/manager.pyi index aef2baf..9ba593c 100644 --- a/django-stubs/db/models/manager.pyi +++ b/django-stubs/db/models/manager.pyi @@ -1,9 +1,10 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Iterable, Union from django.db.models.base import Model from django.db.models.query import QuerySet _T = TypeVar("_T", bound=Model, covariant=True) +_M = TypeVar("_M", bound="BaseManager") class BaseManager(QuerySet[_T]): creation_counter: int = ... @@ -20,13 +21,18 @@ class BaseManager(QuerySet[_T]): @classmethod def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ... def contribute_to_class(self, model: Type[Model], name: str) -> None: ... - def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ... + def db_manager(self: _M, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> _M: ... def get_queryset(self) -> QuerySet[_T]: ... class Manager(BaseManager[_T]): ... class RelatedManager(Manager[_T]): - def add(self, *objs: Model, bulk: bool = ...) -> None: ... + related_val: Tuple[int, ...] + def add(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ... + def remove(self, *objs: Union[_T, int], bulk: bool = ...) -> None: ... + def set( + self, objs: Union[QuerySet[_T], Iterable[Union[_T, int]]], *, bulk: bool = ..., clear: bool = ... + ) -> None: ... def clear(self) -> None: ... class ManagerDescriptor: diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 936413d..59d910b 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -52,8 +52,9 @@ class _BaseQuerySet(Generic[_T], Sized): def get(self, *args: Any, **kwargs: Any) -> _T: ... def create(self, *args: Any, **kwargs: Any) -> _T: ... def bulk_create( - self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ... + self, objs: Iterable[_T], batch_size: Optional[int] = ..., ignore_conflicts: bool = ... ) -> List[_T]: ... + def bulk_update(self, objs: Iterable[_T], fields: Sequence[str], batch_size: Optional[int] = ...) -> None: ... def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ... def update_or_create( self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any @@ -119,8 +120,6 @@ class _BaseQuerySet(Generic[_T], Sized): @property def db(self) -> str: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ... - # TODO: remove when django adds __class_getitem__ methods - def __getattr__(self, item: str) -> Any: ... class QuerySet(_BaseQuerySet[_T], Collection[_T], Sized): def __iter__(self) -> Iterator[_T]: ... diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 9689acb..07952e0 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -187,18 +187,12 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], calculate_mro(new_typeinfo) new_typeinfo.calculate_metaclass_type() - def add_field_to_new_typeinfo(var: Var, is_initialized_in_class: bool = False, - is_property: bool = False) -> None: - var.info = new_typeinfo - var.is_initialized_in_class = is_initialized_in_class - var.is_property = is_property - var._fullname = new_typeinfo.fullname() + '.' + var.name() - new_typeinfo.names[var.name()] = SymbolTableNode(MDEF, var) - # add fields - var_items = [Var(item, typ) for item, typ in fields.items()] - for var_item in var_items: - add_field_to_new_typeinfo(var_item, is_property=True) + for field_name, field_type in fields.items(): + var = Var(field_name, type=field_type) + var.info = new_typeinfo + var._fullname = new_typeinfo.fullname() + '.' + field_name + new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) classdef.info = new_typeinfo module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 195f565..aa95cdb 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -141,6 +141,7 @@ class AddManagers(ModelClassInitializer): has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) if has_manager_any_base: custom_model_manager_name = manager.model.__name__ + '_' + manager.__class__.__name__ + bases = [] for original_base in manager_info.bases: if self._is_manager_any(original_base): @@ -150,11 +151,22 @@ class AddManagers(ModelClassInitializer): original_base = helpers.reparametrize_instance(original_base, [Instance(self.model_classdef.info, [])]) bases.append(original_base) + current_module = self.api.modules[self.model_classdef.info.module_name] custom_manager_info = helpers.add_new_class_for_module(current_module, custom_model_manager_name, bases=bases, fields=OrderedDict()) + # copy fields to a new manager + for name, sym in manager_info.names.items(): + new_sym = sym.copy() + if isinstance(new_sym.node, Var): + new_var = Var(name, type=sym.type) + new_var.info = custom_manager_info + new_var._fullname = custom_manager_info.fullname() + '.' + name + new_sym.node = new_var + custom_manager_info.names[name] = new_sym + custom_manager_type = Instance(custom_manager_info, [Instance(self.model_classdef.info, [])]) self.add_new_node_to_model_class(manager_name, custom_manager_type) diff --git a/scripts/enabled_test_modules.py b/scripts/enabled_test_modules.py index cdfeb57..2bb30b9 100644 --- a/scripts/enabled_test_modules.py +++ b/scripts/enabled_test_modules.py @@ -94,7 +94,8 @@ IGNORED_ERRORS = { 'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"', 'Unexpected attribute "foo" for model "Article"', 'has no attribute "touched"', - 'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"' + 'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"', + '"Manager[Article]" has no attribute "do_something"', ], 'backends': [ '"DatabaseError" has no attribute "pgcode"' @@ -102,6 +103,10 @@ IGNORED_ERRORS = { 'builtin_server': [ '"ServerHandler" has no attribute', ], + 'bulk_create': [ + 'has incompatible type "List[Country]"; expected "Iterable[TwoFields]"', + 'List item 1 has incompatible type "Country"; expected "ProxyCountry"', + ], 'check_framework': [ 'base class "Model" defined the type as "Callable', 'Value of type "Collection[str]" is not indexable', @@ -112,6 +117,8 @@ IGNORED_ERRORS = { ], 'contenttypes_tests': [ '"FooWithBrokenAbsoluteUrl" has no attribute "unknown_field"', + 'contenttypes_tests.models.Site', + 'Argument 1 to "set" of "RelatedManager" has incompatible type "SiteManager[Site]"', ], 'custom_lookups': [ 'in base class "SQLFuncMixin"', @@ -222,6 +229,7 @@ IGNORED_ERRORS = { ], 'm2m_regress': [ "Cannot resolve keyword 'porcupine' into field", + 'Argument 1 to "set" of "RelatedManager" has incompatible type "int"', ], 'messages_tests': [ 'List item 0 has incompatible type "Dict[str, Message]"; expected "Message"', @@ -233,12 +241,14 @@ IGNORED_ERRORS = { ], 'many_to_many': [ '(expression has type "List[Article]", variable has type "RelatedManager[Article]"', + '"add" of "RelatedManager" has incompatible type "Article"; expected "Union[Publication, int]"', ], 'many_to_one': [ 'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")', 'Incompatible type for "parent" of "Child" (got "Child", expected "Union[Parent, Combinable]")', 'expression has type "List[]", variable has type "RelatedManager[Article]"', '"Reporter" has no attribute "cached_query"', + 'to "add" of "RelatedManager" has incompatible type "Reporter"; expected "Union[Article, int]"', ], 'middleware_exceptions': [ 'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any]"; expected "str"' diff --git a/test-data/typecheck/managers/querysets/test_basic_methods.yml b/test-data/typecheck/managers/querysets/test_basic_methods.yml index 36dd6ee..55d4c4e 100644 --- a/test-data/typecheck/managers/querysets/test_basic_methods.yml +++ b/test-data/typecheck/managers/querysets/test_basic_methods.yml @@ -28,4 +28,20 @@ from django.db import models class Blog(models.Model): - created_at = models.DateTimeField() \ No newline at end of file + created_at = models.DateTimeField() + + +- case: queryset_missing_method + main: | + from myapp.models import User + reveal_type(User.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]' + User.objects.not_existing_method() # E: "Manager[User]" has no attribute "not_existing_method" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class User(models.Model): + pass \ No newline at end of file diff --git a/test-data/typecheck/managers/test_managers.yml b/test-data/typecheck/managers/test_managers.yml index 5555952..939a782 100644 --- a/test-data/typecheck/managers/test_managers.yml +++ b/test-data/typecheck/managers/test_managers.yml @@ -309,6 +309,7 @@ from myapp.models import User reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]' + reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int' installed_apps: - myapp files: @@ -317,6 +318,7 @@ content: | from django.db import models class MyManager(models.Manager): - pass + def get_instance(self) -> int: + pass class User(models.Model): objects = MyManager()