Return corresponding descriptors for some fields in class access (#137)

* return corresponding descriptors for some related fields in class access

* return corresponding descriptors for file fields in class access

* fix tests
This commit is contained in:
Maxim Kurnikov
2019-08-23 04:03:03 +03:00
committed by GitHub
parent 656105bab2
commit 09767210ec
3 changed files with 89 additions and 6 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any, Callable, List, Optional, Type, Union, Tuple, Iterable
from typing import Any, Callable, List, Optional, Type, Union, Tuple, Iterable, overload, TypeVar
from django.core.checks.messages import Error
from django.core.files.base import File
from django.core.files.images import ImageFile
from django.core.files.storage import FileSystemStorage, Storage
@@ -34,6 +33,8 @@ class FileDescriptor:
def __set__(self, instance: Model, value: Optional[Any]) -> None: ...
def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Union[FieldFile, FileDescriptor]: ...
_T = TypeVar("_T", bound="Field")
class FileField(Field):
storage: Any = ...
upload_to: Union[str, Callable] = ...
@@ -63,6 +64,15 @@ class FileField(Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
# class access
@overload # type: ignore
def __get__(self, instance: None, owner) -> FileDescriptor: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner) -> Any: ...
# non-Model instances
@overload
def __get__(self: _T, instance, owner) -> _T: ...
def generate_filename(self, instance: Optional[Model], filename: str) -> str: ...
class ImageFileDescriptor(FileDescriptor):
@@ -82,4 +92,13 @@ class ImageField(FileField):
height_field: Optional[str] = ...,
**kwargs: Any
) -> None: ...
# class access
@overload # type: ignore
def __get__(self, instance: None, owner) -> ImageFileDescriptor: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner) -> Any: ...
# non-Model instances
@overload
def __get__(self: _T, instance, owner) -> _T: ...
def update_dimension_fields(self, instance: Model, force: bool = ..., *args: Any, **kwargs: Any) -> None: ...

View File

@@ -1,4 +1,18 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from uuid import UUID
from django.db.models.expressions import Combinable
@@ -25,7 +39,7 @@ if TYPE_CHECKING:
from django.db.models.manager import RelatedManager
_T = TypeVar("_T", bound=models.Model)
_F = TypeVar("_F", bound=models.Field)
_Choice = Tuple[Any, str]
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
@@ -127,6 +141,15 @@ class ForeignKey(ForeignObject[_ST, _GT]):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
# class access
@overload # type: ignore
def __get__(self, instance: None, owner) -> ForwardManyToOneDescriptor: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner) -> _GT: ...
# non-Model instances
@overload
def __get__(self: _F, instance, owner) -> _F: ...
class OneToOneField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Union[Any, Combinable]
@@ -163,6 +186,15 @@ class OneToOneField(RelatedField[_ST, _GT]):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
# class access
@overload # type: ignore
def __get__(self, instance: None, owner) -> ForwardOneToOneDescriptor: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner) -> _GT: ...
# non-Model instances
@overload
def __get__(self: _F, instance, owner) -> _F: ...
class ManyToManyField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Sequence[Any]
@@ -206,7 +238,15 @@ class ManyToManyField(RelatedField[_ST, _GT]):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> None: ...
def __get__(self, instance, owner) -> _GT: ... # type: ignore
# class access
@overload # type: ignore
def __get__(self, instance: None, owner) -> ManyToManyDescriptor: ...
# Model instance access
@overload
def __get__(self, instance: Model, owner) -> _GT: ...
# non-Model instances
@overload
def __get__(self: _F, instance, owner) -> _F: ...
def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ...
def get_reverse_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ...
def contribute_to_related_class(self, cls: Type[Model], related: RelatedField) -> None: ...

View File

@@ -520,3 +520,27 @@
pass
class Article(Entry):
pass
- case: test_related_fields_returned_as_descriptors_from_model_class
main: |
from myapp.models import Author, Blog, Publisher, Profile
reveal_type(Author.blogs) # N: Revealed type is 'django.db.models.fields.related_descriptors.ManyToManyDescriptor'
reveal_type(Blog.publisher) # N: Revealed type is 'django.db.models.fields.related_descriptors.ForwardManyToOneDescriptor'
reveal_type(Publisher.profile) # N: Revealed type is 'django.db.models.fields.related_descriptors.ForwardOneToOneDescriptor'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Profile(models.Model):
pass
class Publisher(models.Model):
profile = models.OneToOneField(Profile, on_delete=models.CASCADE)
class Blog(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
class Author(models.Model):
blogs = models.ManyToManyField(Blog)
file = models.FileField()