diff --git a/django-stubs/db/models/functions/datetime.pyi b/django-stubs/db/models/functions/datetime.pyi index 13b84e3..3ee040a 100644 --- a/django-stubs/db/models/functions/datetime.pyi +++ b/django-stubs/db/models/functions/datetime.pyi @@ -1,9 +1,10 @@ from datetime import datetime -from typing import Any, List, Optional, Set, Tuple, Union +from typing import Any, List, Optional, Set, Tuple, Union, Callable, Dict +from django.db import models from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.models import Func, Transform -from django.db.models.expressions import Col, Expression +from django.db.models.expressions import Col, Expression, Combinable from django.db.models.fields import Field from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query @@ -18,7 +19,7 @@ class Extract(TimezoneMixin, Transform): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] lookup_name: Optional[str] = ... output_field: Any = ... tzinfo: None = ... @@ -46,7 +47,7 @@ class ExtractYear(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -55,7 +56,7 @@ class ExtractMonth(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -64,7 +65,7 @@ class ExtractDay(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -73,7 +74,7 @@ class ExtractWeek(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -82,7 +83,7 @@ class ExtractWeekDay(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -91,7 +92,7 @@ class ExtractQuarter(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -100,7 +101,7 @@ class ExtractHour(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -109,7 +110,7 @@ class ExtractMinute(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -118,7 +119,7 @@ class ExtractSecond(Extract): convert_value: Callable extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None lookup_name: str = ... @@ -156,7 +157,7 @@ class TruncBase(TimezoneMixin, Transform): def convert_value( self, value: datetime, - expression: django.db.models.functions.TruncBase, + expression: models.functions.TruncBase, connection: DatabaseWrapper, ) -> datetime: ... @@ -165,10 +166,10 @@ class Trunc(TruncBase): extra: Dict[Any, Any] is_summary: bool output_field: Union[ - django.db.models.fields.DateTimeCheckMixin, - django.db.models.fields.IntegerField, + models.fields.DateTimeCheckMixin, + models.fields.IntegerField, ] - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None kind: str = ... def __init__( @@ -184,8 +185,8 @@ class TruncYear(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -193,8 +194,8 @@ class TruncQuarter(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -202,8 +203,8 @@ class TruncMonth(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -211,8 +212,8 @@ class TruncWeek(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -220,8 +221,8 @@ class TruncDay(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -229,11 +230,11 @@ class TruncDate(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None kind: str = ... lookup_name: str = ... - output_field: django.db.models.fields.TimeField = ... + output_field: models.fields.TimeField = ... def as_sql( self, compiler: SQLCompiler, connection: DatabaseWrapper ) -> Tuple[str, List[Any]]: ... @@ -242,11 +243,11 @@ class TruncTime(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - source_expressions: List[django.db.models.expressions.Combinable] + source_expressions: List[Combinable] tzinfo: None kind: str = ... lookup_name: str = ... - output_field: django.db.models.fields.DateField = ... + output_field: models.fields.DateField = ... def as_sql( self, compiler: SQLCompiler, connection: DatabaseWrapper ) -> Tuple[str, List[Any]]: ... @@ -255,8 +256,8 @@ class TruncHour(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -264,8 +265,8 @@ class TruncMinute(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... @@ -273,7 +274,7 @@ class TruncSecond(TruncBase): contains_aggregate: bool extra: Dict[Any, Any] is_summary: bool - output_field: django.db.models.fields.DateTimeCheckMixin - source_expressions: List[django.db.models.expressions.Combinable] + output_field: models.fields.DateTimeCheckMixin + source_expressions: List[Combinable] tzinfo: None kind: str = ... diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index de3378e..6117de8 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Dict, Callable from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.models import Func, Transform diff --git a/django-stubs/db/models/functions/window.pyi b/django-stubs/db/models/functions/window.pyi index 1faaa40..10eddd2 100644 --- a/django-stubs/db/models/functions/window.pyi +++ b/django-stubs/db/models/functions/window.pyi @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Dict, List from django.db.models import Func diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 9a6d822..c6e4eb6 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -5,11 +5,12 @@ from datetime import date, datetime from decimal import Decimal from itertools import chain from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, - Type, Union) + Type, Union, Generic, TypeVar, overload) from unittest.mock import MagicMock from uuid import UUID from django.contrib.contenttypes.fields import GenericForeignKey +from django.db import models from django.db.models.base import Model, ModelState from django.db.models.expressions import Expression from django.db.models.fields import Field @@ -38,7 +39,7 @@ class BaseIterable: class ModelIterable(BaseIterable): chunk_size: int chunked_fetch: bool - queryset: django.db.models.query.QuerySet + queryset: QuerySet def __iter__(self) -> Iterator[Model]: ... @@ -46,7 +47,7 @@ class ModelIterable(BaseIterable): class ValuesIterable(BaseIterable): chunk_size: int chunked_fetch: bool - queryset: django.db.models.query.QuerySet + queryset: QuerySet def __iter__(self) -> Iterator[Dict[str, Optional[Union[int, str]]]]: ... @@ -54,7 +55,7 @@ class ValuesIterable(BaseIterable): class ValuesListIterable(BaseIterable): chunk_size: int chunked_fetch: bool - queryset: django.db.models.query.QuerySet + queryset: QuerySet def __iter__(self) -> Union[chain, map]: ... @@ -78,9 +79,12 @@ class FlatValuesListIterable(BaseIterable): def __iter__(self) -> Iterator[Any]: ... -class QuerySet: - model: Optional[Type[django.db.models.base.Model]] = ... - query: django.db.models.sql.query.Query = ... +_T = TypeVar('_T', bound=models.Model) + + +class QuerySet(Generic[_T]): + model: Optional[Type[models.Model]] = ... + query: models.sql.Query = ... def __init__( self, @@ -105,21 +109,28 @@ class QuerySet: ModelState, ], ], - ) -> QuerySet: ... + ) -> QuerySet[_T]: ... def __len__(self) -> int: ... - def __iter__(self) -> Any: ... + def __iter__(self) -> Iterator[_T]: ... def __bool__(self) -> bool: ... - def __getitem__(self, k: Union[int, slice, str]) -> Any: ... + @overload + def __getitem__(self, k: slice) -> QuerySet[_T]: ... + + @overload + def __getitem__(self, k: int) -> _T: ... + + @overload + def __getitem__(self, k: str) -> Any: ... def __and__(self, other: QuerySet) -> QuerySet: ... def __or__(self, other: QuerySet) -> QuerySet: ... - def iterator(self, chunk_size: int = ...) -> Iterator[Any]: ... + def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ... def aggregate( self, *args: Any, **kwargs: Any @@ -136,19 +147,19 @@ class QuerySet: str, ]: ... - def create(self, **kwargs: Any) -> Model: ... + def create(self, **kwargs: Any) -> _T: ... def bulk_create( self, objs: Union[Iterator[Any], List[Model]], batch_size: Optional[int] = ..., - ) -> List[Model]: ... + ) -> List[_T]: ... def get_or_create( self, defaults: Optional[Union[Dict[str, date], Dict[str, Model]]] = ..., **kwargs: Any - ) -> Tuple[Model, bool]: ... + ) -> Tuple[_T, bool]: ... def update_or_create( self, @@ -161,19 +172,19 @@ class QuerySet: ] ] = ..., **kwargs: Any - ) -> Tuple[Model, bool]: ... + ) -> Tuple[_T, bool]: ... def earliest( self, *fields: Any, field_name: Optional[Any] = ... - ) -> Model: ... + ) -> _T: ... def latest( self, *fields: Any, field_name: Optional[Any] = ... - ) -> Model: ... + ) -> _T: ... - def first(self) -> Optional[Union[Dict[str, int], Model]]: ... + def first(self) -> Optional[Union[Dict[str, int], _T]]: ... - def last(self) -> Optional[Model]: ... + def last(self) -> Optional[_T]: ... def in_bulk( self, id_list: Any = ..., *, field_name: str = ... @@ -220,40 +231,40 @@ class QuerySet: self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... ) -> QuerySet: ... - def none(self) -> QuerySet: ... + def none(self) -> QuerySet[_T]: ... - def all(self) -> QuerySet: ... + def all(self) -> QuerySet[_T]: ... - def filter(self, *args: Any, **kwargs: Any) -> QuerySet: ... + def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... - def exclude(self, *args: Any, **kwargs: Any) -> QuerySet: ... + def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... def complex_filter( self, filter_obj: Union[ Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock ], - ) -> QuerySet: ... + ) -> QuerySet[_T]: ... - def union(self, *other_qs: Any, all: bool = ...) -> QuerySet: ... + def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ... - def intersection(self, *other_qs: Any) -> QuerySet: ... + def intersection(self, *other_qs: Any) -> QuerySet[_T]: ... - def difference(self, *other_qs: Any) -> QuerySet: ... + def difference(self, *other_qs: Any) -> QuerySet[_T]: ... def select_for_update( self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ... ) -> QuerySet: ... - def select_related(self, *fields: Any) -> QuerySet: ... + def select_related(self, *fields: Any) -> QuerySet[_T]: ... - def prefetch_related(self, *lookups: Any) -> QuerySet: ... + def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ... - def annotate(self, *args: Any, **kwargs: Any) -> QuerySet: ... + def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... - def order_by(self, *field_names: Any) -> QuerySet: ... + def order_by(self, *field_names: Any) -> QuerySet[_T]: ... - def distinct(self, *field_names: Any) -> QuerySet: ... + def distinct(self, *field_names: Any) -> QuerySet[_T]: ... def extra( self, @@ -265,15 +276,15 @@ class QuerySet: tables: Optional[List[str]] = ..., order_by: Optional[Union[List[str], Tuple[str]]] = ..., select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ..., - ) -> QuerySet: ... + ) -> QuerySet[_T]: ... - def reverse(self) -> QuerySet: ... + def reverse(self) -> QuerySet[_T]: ... - def defer(self, *fields: Any) -> QuerySet: ... + def defer(self, *fields: Any) -> QuerySet[_T]: ... - def only(self, *fields: Any) -> QuerySet: ... + def only(self, *fields: Any) -> QuerySet[_T]: ... - def using(self, alias: Optional[str]) -> QuerySet: ... + def using(self, alias: Optional[str]) -> QuerySet[_T]: ... @property def ordered(self) -> bool: ... @@ -294,13 +305,13 @@ class EmptyQuerySet: class RawQuerySet: columns: List[str] - model_fields: Dict[str, django.db.models.fields.Field] + model_fields: Dict[str, models.Field] raw_query: str = ... - model: Optional[Type[django.db.models.base.Model]] = ... - query: django.db.models.sql.query.RawQuery = ... + model: Optional[Type[models.Model]] = ... + query: models.sql.RawQuery = ... params: Union[ Dict[str, str], - List[datetime.datetime], + List[datetime], List[decimal.Decimal], List[str], Set[str], @@ -359,7 +370,7 @@ class RawQuerySet: class Prefetch: prefetch_through: str = ... prefetch_to: str = ... - queryset: Optional[django.db.models.query.QuerySet] = ... + queryset: Optional[QuerySet] = ... to_attr: Optional[str] = ... def __init__( @@ -415,9 +426,9 @@ class RelatedPopulator: cols_end: int = ... init_list: List[str] = ... reorder_for_init: Optional[operator.itemgetter] = ... - model_cls: Type[django.db.models.base.Model] = ... + model_cls: Type[models.Model] = ... pk_idx: int = ... - related_populators: List[django.db.models.query.RelatedPopulator] = ... + related_populators: List[models.query.RelatedPopulator] = ... local_setter: Callable = ... remote_setter: Callable = ...