make ValuesQuerySet have proper Collection generic type (#140)

This commit is contained in:
Maxim Kurnikov
2019-08-24 18:24:21 +03:00
committed by GitHub
parent 5fc39ff110
commit fc9a335dfd
11 changed files with 93 additions and 88 deletions

View File

@@ -6,6 +6,7 @@ from django.db.models.sql.compiler import SQLCompiler
from django.db.models import Q, QuerySet
from django.db.models.fields import Field
from django.db.models.query import _BaseQuerySet
_OutputField = Union[Field, str]
@@ -125,7 +126,7 @@ class Subquery(Expression):
template: str = ...
queryset: QuerySet = ...
extra: Dict[Any, Any] = ...
def __init__(self, queryset: QuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
def __init__(self, queryset: _BaseQuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
class Exists(Subquery):
negated: bool = ...

View File

@@ -28,9 +28,9 @@ from django.db.models import Manager
from django.db.models.query_utils import Q as Q
_T = TypeVar("_T", bound=models.Model, covariant=True)
_QS = TypeVar("_QS", bound="QuerySet")
_QS = TypeVar("_QS", bound="_BaseQuerySet")
class QuerySet(Generic[_T], Collection[_T], Sized):
class _BaseQuerySet(Generic[_T], Sized):
query: Query
def __init__(
self,
@@ -42,21 +42,13 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
@classmethod
def as_manager(cls) -> Manager[Any]: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _T: ...
@overload
def __getitem__(self, s: slice) -> QuerySet[_T]: ...
def __bool__(self) -> bool: ...
def __class_getitem__(cls, item: Type[_T]):
pass
def __getstate__(self) -> Dict[str, Any]: ...
# __and__ and __or__ ignore the other QuerySet's _Row type parameter
# because they use the same row type as the self QuerySet.
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self, other: QuerySet[_T]) -> QuerySet[_T]: ...
def __or__(self, other: QuerySet[_T]) -> QuerySet[_T]: ...
def __and__(self: _QS, other: _BaseQuerySet[_T]) -> _QS: ...
def __or__(self: _QS, other: _BaseQuerySet[_T]) -> _QS: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
@@ -94,22 +86,22 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
def none(self: _QS) -> _QS: ...
def all(self: _QS) -> _QS: ...
def filter(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def exclude(self: _QS, *args: Any, **kwargs: Any) -> _QS: ...
def complex_filter(self, filter_obj: Any) -> _QS: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T]: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def union(self: _QS, *other_qs: Any, all: bool = ...) -> _QS: ...
def intersection(self: _QS, *other_qs: Any) -> _QS: ...
def difference(self: _QS, *other_qs: Any) -> _QS: ...
def select_for_update(self: _QS, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> _QS: ...
def select_related(self: _QS, *fields: Any) -> _QS: ...
def prefetch_related(self: _QS, *lookups: Any) -> _QS: ...
# TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def order_by(self: _QS, *field_names: Any) -> _QS: ...
def distinct(self: _QS, *field_names: Any) -> _QS: ...
# extra() return type won't be supported any time soon
def extra(
self,
@@ -120,10 +112,10 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
def reverse(self: _QS) -> _QS: ...
def defer(self: _QS, *fields: Any) -> _QS: ...
def only(self: _QS, *fields: Any) -> _QS: ...
def using(self: _QS, alias: Optional[str]) -> _QS: ...
@property
def ordered(self) -> bool: ...
@property
@@ -132,10 +124,18 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
# TODO: remove when django adds __class_getitem__ methods
def __getattr__(self, item: str) -> Any: ...
class QuerySet(_BaseQuerySet[_T], Collection[_T], Sized):
def __iter__(self) -> Iterator[_T]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _T: ...
@overload
def __getitem__(self: _QS, s: slice) -> _QS: ...
_Row = TypeVar("_Row", covariant=True)
class BaseIterable(Sequence[_Row]):
def __init__(self, queryset: QuerySet, chunked_fetch: bool = ..., chunk_size: int = ...): ...
def __init__(self, queryset: _BaseQuerySet, chunked_fetch: bool = ..., chunk_size: int = ...): ...
def __iter__(self) -> Iterator[_Row]: ...
def __contains__(self, x: object) -> bool: ...
def __len__(self) -> int: ...
@@ -152,53 +152,19 @@ class NamedValuesListIterable(ValuesListIterable): ...
class FlatValuesListIterable(BaseIterable):
def __iter__(self) -> Iterator[Any]: ...
class ValuesQuerySet(Generic[_T, _Row], QuerySet[_T], Collection[_Row], Sized):
class ValuesQuerySet(_BaseQuerySet[_T], Collection[_Row], Sized):
def __contains__(self, x: object) -> bool: ...
def __iter__(self) -> Iterator[_Row]: ... # type: ignore
@overload # type: ignore
def __getitem__(self, i: int) -> _Row: ...
@overload
def __getitem__(self, s: slice) -> ValuesQuerySet[_T, _Row]: ...
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self, other: ValuesQuerySet[_T, _Row]) -> ValuesQuerySet[_T, _Row]: ... # type: ignore
def __or__(self, other: ValuesQuerySet[_T, _Row]) -> ValuesQuerySet[_T, _Row]: ... # type: ignore
def __getitem__(self: _QS, s: slice) -> _QS: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ... # type: ignore
def get(self, *args: Any, **kwargs: Any) -> _Row: ... # type: ignore
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ... # type: ignore
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ... # type: ignore
def first(self) -> Optional[_Row]: ... # type: ignore
def last(self) -> Optional[_Row]: ... # type: ignore
def none(self) -> ValuesQuerySet[_T, _Row]: ...
def all(self) -> ValuesQuerySet[_T, _Row]: ...
def filter(self, *args: Any, **kwargs: Any) -> ValuesQuerySet[_T, _Row]: ...
def exclude(self, *args: Any, **kwargs: Any) -> ValuesQuerySet[_T, _Row]: ...
def complex_filter(self, filter_obj: Any) -> ValuesQuerySet[_T, _Row]: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> ValuesQuerySet[_T, _Row]: ...
def intersection(self, *other_qs: Any) -> ValuesQuerySet[_T, _Row]: ...
def difference(self, *other_qs: Any) -> ValuesQuerySet[_T, _Row]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
) -> ValuesQuerySet[_T, _Row]: ...
def select_related(self, *fields: Any) -> ValuesQuerySet[_T, _Row]: ...
def prefetch_related(self, *lookups: Any) -> ValuesQuerySet[_T, _Row]: ...
# TODO: return type
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any]: ...
def order_by(self, *field_names: Any) -> ValuesQuerySet[_T, _Row]: ...
def distinct(self, *field_names: Any) -> ValuesQuerySet[_T, _Row]: ...
# extra() return type won't be supported any time soon
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
where: Optional[List[str]] = ...,
params: Optional[List[Any]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[Any]: ...
def reverse(self) -> ValuesQuerySet[_T, _Row]: ...
def defer(self, *fields: Any) -> ValuesQuerySet[_T, _Row]: ...
def only(self, *fields: Any) -> ValuesQuerySet[_T, _Row]: ...
def using(self, alias: Optional[str]) -> ValuesQuerySet[_T, _Row]: ...
class RawQuerySet(Iterable[_T], Sized):
query: RawQuery

View File

@@ -6,7 +6,7 @@ from uuid import UUID
from django.core.files.base import File
from django.db.models.base import Model
from django.db.models.manager import Manager
from django.db.models.query import QuerySet
from django.db.models.query import QuerySet, _BaseQuerySet
from django.db.models.query_utils import Q
from django.forms.fields import CharField, ChoiceField, Field
from django.forms.forms import BaseForm, DeclarativeFieldsMetaclass
@@ -262,7 +262,7 @@ class ModelMultipleChoiceField(ModelChoiceField):
widget: Any = ...
hidden_widget: Any = ...
default_error_messages: Any = ...
def __init__(self, queryset: QuerySet, **kwargs: Any) -> None: ...
def __init__(self, queryset: _BaseQuerySet, **kwargs: Any) -> None: ...
def _get_foreign_key(
parent_model: Type[Model], model: Type[Model], fk_name: Optional[str] = ..., can_fail: bool = ...

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from django.core.paginator import Paginator
from django.db.models.query import QuerySet
from django.db.models.query import QuerySet, _BaseQuerySet
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View
from django.db.models import Model
@@ -22,14 +22,14 @@ class MultipleObjectMixin(ContextMixin):
object_list: QuerySet = ...
def get_queryset(self) -> QuerySet: ...
def get_ordering(self) -> Sequence[str]: ...
def paginate_queryset(self, queryset: QuerySet, page_size: int) -> Tuple[Paginator, int, QuerySet, bool]: ...
def get_paginate_by(self, queryset: QuerySet) -> Optional[int]: ...
def paginate_queryset(self, queryset: _BaseQuerySet, page_size: int) -> Tuple[Paginator, int, QuerySet, bool]: ...
def get_paginate_by(self, queryset: _BaseQuerySet) -> Optional[int]: ...
def get_paginator(
self, queryset: QuerySet, per_page: int, orphans: int = ..., allow_empty_first_page: bool = ..., **kwargs: Any
) -> Paginator: ...
def get_paginate_orphans(self) -> int: ...
def get_allow_empty(self) -> bool: ...
def get_context_object_name(self, object_list: QuerySet) -> Optional[str]: ...
def get_context_object_name(self, object_list: _BaseQuerySet) -> Optional[str]: ...
class BaseListView(MultipleObjectMixin, View):
def render_to_response(self, context: Dict[str, Any], **response_kwargs: Any) -> HttpResponse: ...

View File

@@ -135,7 +135,15 @@ class DjangoFieldsContext:
related_model_cls = field.field.model
if isinstance(related_model_cls, str):
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
if related_model_cls == 'self':
# same model
related_model_cls = field.model
elif '.' not in related_model_cls:
# same file model
related_model_fullname = field.model.__module__ + '.' + related_model_cls
related_model_cls = self.django_context.get_model_class_by_fullname(related_model_fullname)
else:
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
return related_model_cls

View File

@@ -43,7 +43,9 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
try:
field = model_cls._meta.get_field(field_name)
except FieldDoesNotExist as exc:
ctx.api.fail(exc.args[0], ctx.context)
# if model is abstract, do not raise exception, skip false positives
if not model_cls._meta.abstract:
ctx.api.fail(exc.args[0], ctx.context)
return AnyType(TypeOfAny.from_error)
field_fullname = helpers.get_class_fullname(field.__class__)

View File

@@ -95,7 +95,8 @@ IGNORED_ERRORS = {
'basic': [
'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"',
'Unexpected attribute "foo" for model "Article"',
'has no attribute "touched"'
'has no attribute "touched"',
'Incompatible types in assignment (expression has type "Type[CustomQuerySet]"'
],
'backends': [
'"DatabaseError" has no attribute "pgcode"'
@@ -127,7 +128,7 @@ IGNORED_ERRORS = {
'base class "HttpRequest" defined the type as "QueryDict")'
],
'dates': [
'Too few arguments for "dates" of "QuerySet"',
'Too few arguments for "dates" of',
],
'defer': [
'Too many arguments for "refresh_from_db" of "Model"'
@@ -188,7 +189,7 @@ IGNORED_ERRORS = {
'Argument 1 to "append" of "list" has incompatible type "None"; expected "str"'
],
'lookup': [
'Unexpected keyword argument "headline__startswith" for "in_bulk" of "QuerySet"',
'Unexpected keyword argument "headline__startswith" for "in_bulk" of',
'is called with more than one field'
],
'messages_tests': [
@@ -264,9 +265,10 @@ IGNORED_ERRORS = {
'"Person" has no attribute "houses_lst"',
'"Book" has no attribute "first_authors"',
'"Book" has no attribute "the_authors"',
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room]")',
'Incompatible types in assignment (expression has type "List[Room]", variable has type "Manager[Room]")',
'Item "Room" of "Optional[Room]" has no attribute "house_attr"',
'Item "Room" of "Optional[Room]" has no attribute "main_room_of_attr"'
'Item "Room" of "Optional[Room]" has no attribute "main_room_of_attr"',
'Argument 2 to "Prefetch" has incompatible type "ValuesQuerySet'
],
'proxy_models': [
'Incompatible types in assignment',
@@ -275,12 +277,13 @@ IGNORED_ERRORS = {
'queries': [
'Incompatible types in assignment (expression has type "None", variable has type "str")',
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"',
'Unsupported operand types for & ("QuerySet[Author]" and "QuerySet[Tag]")',
'Unsupported operand types for | ("QuerySet[Author]" and "QuerySet[Tag]")',
'Unsupported operand types for & ("Manager[Author]" and "Manager[Tag]")',
'Unsupported operand types for | ("Manager[Author]" and "Manager[Tag]")',
'ObjectA',
'ObjectB',
'ObjectC',
"'flat' and 'named' can't be used together",
'"Collection[Any]" has no attribute "explain"'
],
'requests': [
'Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "QueryDict")'
@@ -294,6 +297,9 @@ IGNORED_ERRORS = {
'signals': [
'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any, Optional[Any], Any]";'
],
'sites_framework': [
'expression has type "CurrentSiteManager[CustomArticle]", base class "AbstractArticle"'
],
'syndication_tests': [
'List or tuple expected as variable arguments'
],

View File

@@ -3,14 +3,14 @@
from myapp.models import Blog
qs = Blog.objects.all()
reveal_type(qs) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Blog*]'
reveal_type(qs) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.Blog]'
reveal_type(qs.get(id=1)) # N: Revealed type is 'myapp.models.Blog*'
reveal_type(iter(qs)) # N: Revealed type is 'typing.Iterator[myapp.models.Blog*]'
reveal_type(qs.iterator()) # N: Revealed type is 'typing.Iterator[myapp.models.Blog*]'
reveal_type(qs.first()) # N: Revealed type is 'Union[myapp.models.Blog*, None]'
reveal_type(qs.earliest()) # N: Revealed type is 'myapp.models.Blog*'
reveal_type(qs[0]) # N: Revealed type is 'myapp.models.Blog*'
reveal_type(qs[:9]) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Blog*]'
reveal_type(qs[:9]) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.Blog]'
reveal_type(qs.in_bulk()) # N: Revealed type is 'builtins.dict[Any, myapp.models.Blog*]'
# .dates / .datetimes
@@ -18,7 +18,7 @@
reveal_type(Blog.objects.datetimes("created_at", "day")) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Blog*, datetime.datetime]'
# AND-ing QuerySets
reveal_type(Blog.objects.all() & Blog.objects.all()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Blog*]'
reveal_type(Blog.objects.all() & Blog.objects.all()) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.Blog]'
installed_apps:
- myapp
files:

View File

@@ -53,7 +53,7 @@
text = models.CharField(max_length=100)
blog = models.ForeignKey(to=Blog, on_delete=models.CASCADE)
- case: values_list_flat_true
- case: values_list_flat_true_methods
main: |
from myapp.models import MyUser, MyUser2
reveal_type(MyUser.objects.values_list('name', flat=True).get()) # N: Revealed type is 'builtins.str*'
@@ -194,6 +194,8 @@
from myapp.models import Blog, Publisher
reveal_type(Blog.objects.values_list('id', flat=True)) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Blog, builtins.int]'
reveal_type(Blog.objects.values_list('publisher_id', flat=True)) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Blog, builtins.int]'
# is Iterable[int]
reveal_type(list(Blog.objects.values_list('id', flat=True))) # N: Revealed type is 'builtins.list[builtins.int*]'
installed_apps:
- myapp
files:

View File

@@ -308,7 +308,7 @@
main: |
from myapp.models import User
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.select_related()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.User*]'
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
installed_apps:
- myapp
files:

View File

@@ -36,3 +36,23 @@
name = models.CharField(max_length=100)
age = models.IntegerField()
to_user = models.ForeignKey('self', on_delete=models.SET_NULL)
- case: get_field_with_abstract_inheritance
main: |
from myapp.models import AbstractModel
class MyModel(AbstractModel):
pass
reveal_type(MyModel._meta.get_field('field')) # N: Revealed type is 'Any'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
from django.contrib.postgres.fields import ArrayField
class AbstractModel(models.Model):
class Meta:
abstract = True
class MyModel(AbstractModel):
field = ArrayField(models.IntegerField(), default=[])