From 09767210ecec25aa05e9c9a0994207167ceeede8 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Fri, 23 Aug 2019 04:03:03 +0300 Subject: [PATCH] Return corresponding descriptors for some fields in class access (#137) * return corresponding descriptors for some related fields in class access * return corresponding descriptors for file fields in class access * fix tests --- django-stubs/db/models/fields/files.pyi | 23 ++++++++++- django-stubs/db/models/fields/related.pyi | 46 +++++++++++++++++++-- test-data/typecheck/fields/test_related.yml | 26 +++++++++++- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/django-stubs/db/models/fields/files.pyi b/django-stubs/db/models/fields/files.pyi index 945be31..311e601 100644 --- a/django-stubs/db/models/fields/files.pyi +++ b/django-stubs/db/models/fields/files.pyi @@ -1,6 +1,5 @@ -from typing import Any, Callable, List, Optional, Type, Union, Tuple, Iterable +from typing import Any, Callable, List, Optional, Type, Union, Tuple, Iterable, overload, TypeVar -from django.core.checks.messages import Error from django.core.files.base import File from django.core.files.images import ImageFile from django.core.files.storage import FileSystemStorage, Storage @@ -34,6 +33,8 @@ class FileDescriptor: def __set__(self, instance: Model, value: Optional[Any]) -> None: ... def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Union[FieldFile, FileDescriptor]: ... +_T = TypeVar("_T", bound="Field") + class FileField(Field): storage: Any = ... upload_to: Union[str, Callable] = ... @@ -63,6 +64,15 @@ class FileField(Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... + # class access + @overload # type: ignore + def __get__(self, instance: None, owner) -> FileDescriptor: ... + # Model instance access + @overload + def __get__(self, instance: Model, owner) -> Any: ... + # non-Model instances + @overload + def __get__(self: _T, instance, owner) -> _T: ... def generate_filename(self, instance: Optional[Model], filename: str) -> str: ... class ImageFileDescriptor(FileDescriptor): @@ -82,4 +92,13 @@ class ImageField(FileField): height_field: Optional[str] = ..., **kwargs: Any ) -> None: ... + # class access + @overload # type: ignore + def __get__(self, instance: None, owner) -> ImageFileDescriptor: ... + # Model instance access + @overload + def __get__(self, instance: Model, owner) -> Any: ... + # non-Model instances + @overload + def __get__(self: _T, instance, owner) -> _T: ... def update_dimension_fields(self, instance: Model, force: bool = ..., *args: Any, **kwargs: Any) -> None: ... diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index 6192cd3..74ab259 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -1,4 +1,18 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, + overload, +) from uuid import UUID from django.db.models.expressions import Combinable @@ -25,7 +39,7 @@ if TYPE_CHECKING: from django.db.models.manager import RelatedManager _T = TypeVar("_T", bound=models.Model) - +_F = TypeVar("_F", bound=models.Field) _Choice = Tuple[Any, str] _ChoiceNamedGroup = Tuple[str, Iterable[_Choice]] _FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]] @@ -127,6 +141,15 @@ class ForeignKey(ForeignObject[_ST, _GT]): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... + # class access + @overload # type: ignore + def __get__(self, instance: None, owner) -> ForwardManyToOneDescriptor: ... + # Model instance access + @overload + def __get__(self, instance: Model, owner) -> _GT: ... + # non-Model instances + @overload + def __get__(self: _F, instance, owner) -> _F: ... class OneToOneField(RelatedField[_ST, _GT]): _pyi_private_set_type: Union[Any, Combinable] @@ -163,6 +186,15 @@ class OneToOneField(RelatedField[_ST, _GT]): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... + # class access + @overload # type: ignore + def __get__(self, instance: None, owner) -> ForwardOneToOneDescriptor: ... + # Model instance access + @overload + def __get__(self, instance: Model, owner) -> _GT: ... + # non-Model instances + @overload + def __get__(self: _F, instance, owner) -> _F: ... class ManyToManyField(RelatedField[_ST, _GT]): _pyi_private_set_type: Sequence[Any] @@ -206,7 +238,15 @@ class ManyToManyField(RelatedField[_ST, _GT]): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ) -> None: ... - def __get__(self, instance, owner) -> _GT: ... # type: ignore + # class access + @overload # type: ignore + def __get__(self, instance: None, owner) -> ManyToManyDescriptor: ... + # Model instance access + @overload + def __get__(self, instance: Model, owner) -> _GT: ... + # non-Model instances + @overload + def __get__(self: _F, instance, owner) -> _F: ... def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ... def get_reverse_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ... def contribute_to_related_class(self, cls: Type[Model], related: RelatedField) -> None: ... diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index 3951db3..c310277 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -519,4 +519,28 @@ class Book(Entry): pass class Article(Entry): - pass \ No newline at end of file + pass + + +- case: test_related_fields_returned_as_descriptors_from_model_class + main: | + from myapp.models import Author, Blog, Publisher, Profile + reveal_type(Author.blogs) # N: Revealed type is 'django.db.models.fields.related_descriptors.ManyToManyDescriptor' + reveal_type(Blog.publisher) # N: Revealed type is 'django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor' + reveal_type(Publisher.profile) # N: Revealed type is 'django.db.models.fields.related_descriptors.ForwardOneToOneDescriptor' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Profile(models.Model): + pass + class Publisher(models.Model): + profile = models.OneToOneField(Profile, on_delete=models.CASCADE) + class Blog(models.Model): + publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) + class Author(models.Model): + blogs = models.ManyToManyField(Blog) + file = models.FileField()