Recover after #909 (#925)

* Fix stubs related to `(Async)RequestFactory` and `(Async)Client`

* Revert incorrect removal.

* Allow set as `unique_together`, use shared type alias.

* Revert `Q.__init__` to use only `*args, **kwargs` to remove false-positive with `Q(**{...})`

* Add abstract methods to `HttpResponseBase` to create common interface.

* Remove monkey-patched attributes from `HttpResponseBase` subclasses.

* Add QueryDict mutability checks (+ plugin support)

* Fix lint

* Return back GenericForeignKey to `Options.get_fields`

* Minor fixup

* Make plugin code typecheck with `--warn-unreachable`, minor performance increase.

* Better types for `{unique, index}_together` and Options.

* Fix odd type of `URLResolver.urlconf_name` which isn't a str actually.

* Better types for field migration operations.

* Revert form.files to `MultiValueDict[str, UploadedFile]`

* Compatibility fix (#916)

* Do not assume that `Annotated` is always related to django-stubs (fixes #893)

* Restrict `FormView.get_form` return type to `_FormT` (class type argument). Now it is resolved to `form_class` argument if present, but also errors if it is not subclass of _FormT

* Fix CI (make test runnable on 3.8)

* Fix CI (make test runnable on 3.8 _again_)
This commit is contained in:
sterliakov
2022-04-28 13:01:37 +03:00
committed by GitHub
parent 16499a22ab
commit 6226381484
29 changed files with 380 additions and 138 deletions

View File

@@ -15,7 +15,7 @@ from typing import (
) )
from django.core.handlers import base as base from django.core.handlers import base as base
from django.http import HttpRequest, QueryDict from django.http.request import HttpRequest, _ImmutableQueryDict
from django.http.response import HttpResponseBase from django.http.response import HttpResponseBase
from django.urls.resolvers import ResolverMatch, URLResolver from django.urls.resolvers import ResolverMatch, URLResolver
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
@@ -34,8 +34,8 @@ class ASGIRequest(HttpRequest):
META: Dict[str, Any] = ... META: Dict[str, Any] = ...
def __init__(self, scope: Mapping[str, Any], body_file: IO[bytes]) -> None: ... def __init__(self, scope: Mapping[str, Any], body_file: IO[bytes]) -> None: ...
@property @property
def GET(self) -> QueryDict: ... # type: ignore def GET(self) -> _ImmutableQueryDict: ... # type: ignore
POST: QueryDict = ... POST: _ImmutableQueryDict = ...
FILES: MultiValueDict = ... FILES: MultiValueDict = ...
@property @property
def COOKIES(self) -> Dict[str, str]: ... # type: ignore def COOKIES(self) -> Dict[str, str]: ... # type: ignore

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional from typing import Optional
from django.db.models.fields import Field from django.db.models.fields import Field
@@ -23,15 +23,15 @@ class AddField(FieldOperation):
class RemoveField(FieldOperation): ... class RemoveField(FieldOperation): ...
class AlterField(FieldOperation): class AlterField(FieldOperation):
field: Any = ... field: Field = ...
preserve_default: Any = ... preserve_default: bool = ...
def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = ...) -> None: ... def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = ...) -> None: ...
class RenameField(FieldOperation): class RenameField(FieldOperation):
old_name: Any = ... old_name: str = ...
new_name: Any = ... new_name: str = ...
def __init__(self, model_name: str, old_name: str, new_name: str) -> None: ... def __init__(self, model_name: str, old_name: str, new_name: str) -> None: ...
@property @property
def old_name_lower(self): ... def old_name_lower(self) -> str: ...
@property @property
def new_name_lower(self) -> str: ... def new_name_lower(self) -> str: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.operations.base import Operation from django.db.migrations.operations.base import Operation
@@ -7,9 +7,7 @@ from django.db.models.constraints import BaseConstraint
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.indexes import Index from django.db.models.indexes import Index
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.utils.datastructures import _ListOrTuple from django.db.models.options import _OptionTogetherT
_T = TypeVar("_T")
class ModelOperation(Operation): class ModelOperation(Operation):
name: str = ... name: str = ...
@@ -53,10 +51,10 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
def __init__( def __init__(
self, self,
name: str, name: str,
option_value: _ListOrTuple[Tuple[str, str]], option_value: Optional[_OptionTogetherT],
) -> None: ... ) -> None: ...
@property @property
def option_value(self) -> Set[Tuple[str, str]]: ... def option_value(self) -> Optional[Set[Tuple[str, ...]]]: ...
def deconstruct(self) -> Tuple[str, Sequence[Any], Dict[str, Any]]: ... def deconstruct(self) -> Tuple[str, Sequence[Any], Dict[str, Any]]: ...
def state_forwards(self, app_label: str, state: Any) -> None: ... def state_forwards(self, app_label: str, state: Any) -> None: ...
def database_forwards( def database_forwards(
@@ -72,26 +70,26 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
class AlterUniqueTogether(AlterTogetherOptionOperation): class AlterUniqueTogether(AlterTogetherOptionOperation):
option_name: str = ... option_name: str = ...
unique_together: _ListOrTuple[Tuple[str, str]] = ... unique_together: Optional[Set[Tuple[str, ...]]] = ...
def __init__(self, name: str, unique_together: _ListOrTuple[Tuple[str, str]]) -> None: ... def __init__(self, name: str, unique_together: Optional[_OptionTogetherT]) -> None: ...
class AlterIndexTogether(AlterTogetherOptionOperation): class AlterIndexTogether(AlterTogetherOptionOperation):
option_name: str = ... option_name: str = ...
index_together: _ListOrTuple[Tuple[str, str]] = ... index_together: Optional[Set[Tuple[str, ...]]] = ...
def __init__(self, name: str, index_together: _ListOrTuple[Tuple[str, str]]) -> None: ... def __init__(self, name: str, index_together: Optional[_OptionTogetherT]) -> None: ...
class AlterOrderWithRespectTo(ModelOptionOperation): class AlterOrderWithRespectTo(ModelOptionOperation):
order_with_respect_to: str = ... order_with_respect_to: str = ...
def __init__(self, name: str, order_with_respect_to: str) -> None: ... def __init__(self, name: str, order_with_respect_to: str) -> None: ...
class AlterModelOptions(ModelOptionOperation): class AlterModelOptions(ModelOptionOperation):
ALTER_OPTION_KEYS: Any = ... ALTER_OPTION_KEYS: List[str] = ...
options: Dict[str, str] = ... options: Dict[str, Any] = ...
def __init__(self, name: str, options: Dict[str, Any]) -> None: ... def __init__(self, name: str, options: Dict[str, Any]) -> None: ...
class AlterModelManagers(ModelOptionOperation): class AlterModelManagers(ModelOptionOperation):
managers: Any = ... managers: Sequence[Manager] = ...
def __init__(self, name: Any, managers: Any) -> None: ... def __init__(self, name: str, managers: Sequence[Manager]) -> None: ...
class IndexOperation(Operation): class IndexOperation(Operation):
option_name: str = ... option_name: str = ...

View File

@@ -1,13 +1,9 @@
import sys import sys
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union from typing import Any, Mapping, Optional, Sequence, Union
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.state import StateApps from django.db.migrations.state import StateApps
from django.utils.datastructures import _ListOrTuple
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
from .base import Operation from .base import Operation
@@ -25,14 +21,14 @@ class SeparateDatabaseAndState(Operation):
class RunSQL(Operation): class RunSQL(Operation):
noop: Literal[""] = ... noop: Literal[""] = ...
sql: Union[str, List[str], Tuple[str, ...]] = ... sql: Union[str, _ListOrTuple[str]] = ...
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ... reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...
state_operations: Sequence[Operation] = ... state_operations: Sequence[Operation] = ...
hints: Mapping[str, Any] = ... hints: Mapping[str, Any] = ...
def __init__( def __init__(
self, self,
sql: Union[str, List[str], Tuple[str, ...]], sql: Union[str, _ListOrTuple[str]],
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ..., reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...,
state_operations: Sequence[Operation] = ..., state_operations: Sequence[Operation] = ...,
hints: Optional[Mapping[str, Any]] = ..., hints: Optional[Mapping[str, Any]] = ...,
elidable: bool = ..., elidable: bool = ...,
@@ -44,13 +40,13 @@ class _CodeCallable(Protocol):
class RunPython(Operation): class RunPython(Operation):
code: _CodeCallable = ... code: _CodeCallable = ...
reverse_code: Optional[_CodeCallable] = ... reverse_code: Optional[_CodeCallable] = ...
hints: Optional[Dict[str, Any]] = ... hints: Mapping[str, Any] = ...
def __init__( def __init__(
self, self,
code: _CodeCallable, code: _CodeCallable,
reverse_code: Optional[_CodeCallable] = ..., reverse_code: Optional[_CodeCallable] = ...,
atomic: Optional[bool] = ..., atomic: Optional[bool] = ...,
hints: Optional[Dict[str, Any]] = ..., hints: Optional[Mapping[str, Any]] = ...,
elidable: bool = ..., elidable: bool = ...,
) -> None: ... ) -> None: ...
@staticmethod @staticmethod

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union import sys
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, overload
from django.apps.config import AppConfig from django.apps.config import AppConfig
from django.apps.registry import Apps from django.apps.registry import Apps
@@ -11,16 +12,26 @@ from django.db.models.fields.related import ManyToManyField, OneToOneField
from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils.datastructures import ImmutableList from django.utils.datastructures import ImmutableList, _ListOrTuple
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
PROXY_PARENTS: object PROXY_PARENTS: object
EMPTY_RELATION_TREE: Any EMPTY_RELATION_TREE: Any
IMMUTABLE_WARNING: str IMMUTABLE_WARNING: str
DEFAULT_NAMES: Tuple[str, ...] DEFAULT_NAMES: Tuple[str, ...]
def normalize_together( _OptionTogetherT = Union[_ListOrTuple[Union[_ListOrTuple[str], str]], Set[Tuple[str, ...]]]
option_together: Union[List[Tuple[str, str]], Tuple[Tuple[str, str], ...], Tuple[()], Tuple[str, str]]
) -> Tuple[Tuple[str, str], ...]: ... @overload
def normalize_together(option_together: _ListOrTuple[Union[_ListOrTuple[str], str]]) -> Tuple[Tuple[str, ...], ...]: ...
# Any other value will be returned unchanged, but probably only set is semantically allowed
@overload
def normalize_together(option_together: Set[Tuple[str, ...]]) -> Set[Tuple[str, ...]]: ...
_T = TypeVar("_T") _T = TypeVar("_T")
@@ -45,18 +56,18 @@ class Options(Generic[_M]):
db_table: str = ... db_table: str = ...
ordering: Optional[Sequence[str]] = ... ordering: Optional[Sequence[str]] = ...
indexes: List[Any] = ... indexes: List[Any] = ...
unique_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ... unique_together: Sequence[Tuple[str]] = ... # Are always normalized
index_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ... index_together: Sequence[Tuple[str]] = ... # Are always normalized
select_on_save: bool = ... select_on_save: bool = ...
default_permissions: Sequence[str] = ... default_permissions: Sequence[str] = ...
permissions: List[Any] = ... permissions: List[Any] = ...
object_name: Optional[str] = ... object_name: Optional[str] = ...
app_label: str = ... app_label: str = ...
get_latest_by: Optional[Sequence[str]] = ... get_latest_by: Optional[Sequence[str]] = ...
order_with_respect_to: Optional[Any] = ... order_with_respect_to: Optional[str] = ...
db_tablespace: str = ... db_tablespace: str = ...
required_db_features: List[Any] = ... required_db_features: List[str] = ...
required_db_vendor: Any = ... required_db_vendor: Optional[Literal["sqlite", "postgresql", "mysql", "oracle"]] = ...
meta: Optional[type] = ... meta: Optional[type] = ...
pk: Optional[Field] = ... pk: Optional[Field] = ...
auto_field: Optional[AutoField] = ... auto_field: Optional[AutoField] = ...
@@ -105,7 +116,7 @@ class Options(Generic[_M]):
def default_manager(self) -> Optional[Manager]: ... def default_manager(self) -> Optional[Manager]: ...
@property @property
def fields(self) -> ImmutableList[Field[Any, Any]]: ... def fields(self) -> ImmutableList[Field[Any, Any]]: ...
def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel]: ... def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel, GenericForeignKey]: ...
def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ... def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ...
def get_parent_list(self) -> List[Type[Model]]: ... def get_parent_list(self) -> List[Type[Model]]: ...
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ... def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
@@ -113,7 +124,7 @@ class Options(Generic[_M]):
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ... def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_fields( def get_fields(
self, include_parents: bool = ..., include_hidden: bool = ... self, include_parents: bool = ..., include_hidden: bool = ...
) -> List[Union[Field[Any, Any], ForeignObjectRel]]: ... ) -> List[Union[Field[Any, Any], ForeignObjectRel, GenericForeignKey]]: ...
@property @property
def total_unique_constraints(self) -> List[UniqueConstraint]: ... def total_unique_constraints(self) -> List[UniqueConstraint]: ...
@property @property

View File

@@ -45,7 +45,9 @@ class Q(tree.Node):
AND: str = ... AND: str = ...
OR: str = ... OR: str = ...
conditional: bool = ... conditional: bool = ...
def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ... def __init__(self, *args: Any, **kwargs: Any) -> None: ...
# Fake signature, the real is
# def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ...
def __or__(self, other: Q) -> Q: ... def __or__(self, other: Q) -> Q: ...
def __and__(self, other: Q) -> Q: ... def __and__(self, other: Q) -> Q: ...
def __invert__(self) -> Q: ... def __invert__(self) -> Q: ...

View File

@@ -41,6 +41,8 @@ class ConnectionRouter:
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ... def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
@property @property
def routers(self) -> List[Any]: ... def routers(self) -> List[Any]: ...
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ... def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ... def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ... def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...

View File

@@ -3,11 +3,12 @@ from datetime import datetime
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.files import File from django.core.files.uploadedfile import UploadedFile
from django.utils.datastructures import MultiValueDict
from django.utils.safestring import SafeString from django.utils.safestring import SafeString
_DataT = Mapping[str, Any] _DataT = Mapping[str, Any]
_FilesT = Mapping[str, Iterable[File]] _FilesT = MultiValueDict[str, UploadedFile]
def pretty_name(name: str) -> str: ... def pretty_name(name: str) -> str: ...
def flatatt(attrs: Dict[str, Any]) -> SafeString: ... def flatatt(attrs: Dict[str, Any]) -> SafeString: ...

View File

@@ -1,3 +1,4 @@
import sys
from io import BytesIO from io import BytesIO
from typing import ( from typing import (
Any, Any,
@@ -6,6 +7,7 @@ from typing import (
Iterable, Iterable,
List, List,
Mapping, Mapping,
NoReturn,
Optional, Optional,
Pattern, Pattern,
Set, Set,
@@ -24,6 +26,11 @@ from django.core.files import uploadedfile, uploadhandler
from django.urls import ResolverMatch from django.urls import ResolverMatch
from django.utils.datastructures import CaseInsensitiveMapping, ImmutableList, MultiValueDict from django.utils.datastructures import CaseInsensitiveMapping, ImmutableList, MultiValueDict
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
RAISE_ERROR: object = ... RAISE_ERROR: object = ...
host_validation_re: Pattern[str] = ... host_validation_re: Pattern[str] = ...
@@ -40,8 +47,8 @@ class HttpHeaders(CaseInsensitiveMapping[str]):
def parse_header_name(cls, header: str) -> Optional[str]: ... def parse_header_name(cls, header: str) -> Optional[str]: ...
class HttpRequest(BytesIO): class HttpRequest(BytesIO):
GET: QueryDict = ... GET: _ImmutableQueryDict = ...
POST: QueryDict = ... POST: _ImmutableQueryDict = ...
COOKIES: Dict[str, str] = ... COOKIES: Dict[str, str] = ...
META: Dict[str, Any] = ... META: Dict[str, Any] = ...
FILES: MultiValueDict[str, uploadedfile.UploadedFile] = ... FILES: MultiValueDict[str, uploadedfile.UploadedFile] = ...
@@ -55,7 +62,15 @@ class HttpRequest(BytesIO):
site: Site site: Site
session: SessionBase session: SessionBase
_stream: BinaryIO _stream: BinaryIO
def __init__(self) -> None: ... # The magic. If we instantiate HttpRequest directly somewhere, it has
# mutable GET and POST. However, both ASGIRequest and WSGIRequest have immutable,
# so when we use HttpRequest to refer to any of them we want exactly this.
# Case when some function creates *exactly* HttpRequest (not subclass)
# remain uncovered, however it's probably the best solution we can afford.
def __new__(cls) -> _MutableHttpRequest: ...
# When both __init__ and __new__ are present, mypy will prefer __init__
# (see comments in mypy.checkmember.type_object_type)
# def __init__(self) -> None: ...
def get_host(self) -> str: ... def get_host(self) -> str: ...
def get_port(self) -> str: ... def get_port(self) -> str: ...
def get_full_path(self, force_append_slash: bool = ...) -> str: ... def get_full_path(self, force_append_slash: bool = ...) -> str: ...
@@ -90,17 +105,43 @@ class HttpRequest(BytesIO):
def _load_post_and_files(self) -> None: ... def _load_post_and_files(self) -> None: ...
def accepts(self, media_type: str) -> bool: ... def accepts(self, media_type: str) -> bool: ...
class _MutableHttpRequest(HttpRequest):
GET: QueryDict = ... # type: ignore[assignment]
POST: QueryDict = ... # type: ignore[assignment]
_Q = TypeVar("_Q", bound="QueryDict") _Q = TypeVar("_Q", bound="QueryDict")
_Z = TypeVar("_Z")
class QueryDict(MultiValueDict[str, str]): class QueryDict(MultiValueDict[str, str]):
_mutable: bool = ... _mutable: bool = ...
# We can make it mutable only by specifying `mutable=True`.
# It can be done a) with kwarg and b) with pos. arg. `overload` has
# some problems with args/kwargs + Literal, so two signatures are required.
# ('querystring', True, [...])
@overload
def __init__( def __init__(
self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ... self: QueryDict,
query_string: Optional[Union[str, bytes]],
mutable: Literal[True],
encoding: Optional[str] = ...,
) -> None: ...
# ([querystring='string',] mutable=True, [...])
@overload
def __init__(
self: QueryDict,
*,
mutable: Literal[True],
query_string: Optional[Union[str, bytes]] = ...,
encoding: Optional[str] = ...,
) -> None: ...
# Otherwise it's immutable
@overload
def __init__( # type: ignore[misc]
self: _ImmutableQueryDict,
query_string: Optional[Union[str, bytes]] = ...,
mutable: bool = ...,
encoding: Optional[str] = ...,
) -> None: ... ) -> None: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
def urlencode(self, safe: Optional[str] = ...) -> str: ...
@classmethod @classmethod
def fromkeys( # type: ignore def fromkeys( # type: ignore
cls: Type[_Q], cls: Type[_Q],
@@ -114,7 +155,46 @@ class QueryDict(MultiValueDict[str, str]):
@encoding.setter @encoding.setter
def encoding(self, value: str) -> None: ... def encoding(self, value: str) -> None: ...
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ... def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
def __delitem__(self, key: Union[str, bytes]) -> None: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
@overload
def pop(self, key: Union[str, bytes], /) -> str: ...
@overload
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> Union[str, _Z]: ...
def popitem(self) -> Tuple[str, str]: ...
def clear(self) -> None: ...
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> str: ... def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> str: ...
def copy(self) -> QueryDict: ...
def urlencode(self, safe: Optional[str] = ...) -> str: ...
class _ImmutableQueryDict(QueryDict):
_mutable: Literal[False]
# def __init__(
# self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ...
# ) -> None: ...
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
def __delitem__(self, key: Union[str, bytes]) -> NoReturn: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> NoReturn: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> NoReturn: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
@overload
def pop(self, key: Union[str, bytes], /) -> NoReturn: ...
@overload
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> NoReturn: ...
def popitem(self) -> NoReturn: ...
def clear(self) -> NoReturn: ...
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> NoReturn: ...
def copy(self) -> QueryDict: ... # type: ignore[override]
def urlencode(self, safe: Optional[str] = ...) -> str: ...
# Fakes for convenience (for `request.GET` and `request.POST`). If dict
# was created by Django, there is no chance to hit `List[object]` (empty list)
# edge case.
def __getitem__(self, key: str) -> str: ...
def dict(self) -> Dict[str, str]: ... # type: ignore[override]
class MediaType: class MediaType:
def __init__(self, media_type_raw_line: str) -> None: ... def __init__(self, media_type_raw_line: str) -> None: ...

View File

@@ -1,13 +1,9 @@
import datetime import datetime
from io import BytesIO from io import BytesIO
from json import JSONEncoder from json import JSONEncoder
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload, type_check_only
from django.core.handlers.wsgi import WSGIRequest
from django.http.cookie import SimpleCookie from django.http.cookie import SimpleCookie
from django.template import Context, Template
from django.test.client import Client
from django.urls import ResolverMatch
from django.utils.datastructures import CaseInsensitiveMapping, _PropertyDescriptor from django.utils.datastructures import CaseInsensitiveMapping, _PropertyDescriptor
class BadHeaderError(ValueError): ... class BadHeaderError(ValueError): ...
@@ -82,6 +78,12 @@ class HttpResponseBase:
def writable(self) -> bool: ... def writable(self) -> bool: ...
def writelines(self, lines: Iterable[object]) -> None: ... def writelines(self, lines: Iterable[object]) -> None: ...
# Fake methods that are implemented by all subclasses
@type_check_only
def __iter__(self) -> Iterator[bytes]: ...
@type_check_only
def getvalue(self) -> bytes: ...
class HttpResponse(HttpResponseBase, Iterable[bytes]): class HttpResponse(HttpResponseBase, Iterable[bytes]):
content = _PropertyDescriptor[object, bytes]() content = _PropertyDescriptor[object, bytes]()
csrf_cookie_set: bool csrf_cookie_set: bool
@@ -94,14 +96,6 @@ class HttpResponse(HttpResponseBase, Iterable[bytes]):
def serialize(self) -> bytes: ... def serialize(self) -> bytes: ...
__bytes__ = serialize __bytes__ = serialize
def __iter__(self) -> Iterator[bytes]: ... def __iter__(self) -> Iterator[bytes]: ...
# Attributes assigned by monkey-patching in test client ClientHandler.__call__()
wsgi_request: WSGIRequest
# Attributes assigned by monkey-patching in test client Client.request()
client: Client
request: Dict[str, Any]
templates: List[Template]
context: Context
resolver_match: ResolverMatch
def getvalue(self) -> bytes: ... def getvalue(self) -> bytes: ...
class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]): class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
@@ -111,13 +105,7 @@ class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
def getvalue(self) -> bytes: ... def getvalue(self) -> bytes: ...
class FileResponse(StreamingHttpResponse): class FileResponse(StreamingHttpResponse):
client: Client
context: None
file_to_stream: Optional[BytesIO] file_to_stream: Optional[BytesIO]
request: Dict[str, str]
resolver_match: ResolverMatch
templates: List[Any]
wsgi_request: WSGIRequest
block_size: int = ... block_size: int = ...
as_attachment: bool = ... as_attachment: bool = ...
filename: str = ... filename: str = ...

View File

@@ -3,6 +3,7 @@ from json import JSONEncoder
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Generic, Generic,
@@ -26,6 +27,8 @@ from django.core.handlers.wsgi import WSGIRequest
from django.http.cookie import SimpleCookie from django.http.cookie import SimpleCookie
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.http.response import HttpResponseBase from django.http.response import HttpResponseBase
from django.template.base import Template
from django.urls import ResolverMatch
BOUNDARY: str = ... BOUNDARY: str = ...
MULTIPART_CONTENT: str = ... MULTIPART_CONTENT: str = ...
@@ -49,27 +52,25 @@ _T = TypeVar("_T")
def closing_iterator_wrapper(iterable: Iterable[_T], close: Callable[[], Any]) -> Iterator[_T]: ... def closing_iterator_wrapper(iterable: Iterable[_T], close: Callable[[], Any]) -> Iterator[_T]: ...
def conditional_content_removal(request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ... def conditional_content_removal(request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ...
class _WSGIResponse(HttpResponseBase):
wsgi_request: WSGIRequest
class _ASGIResponse(HttpResponseBase):
asgi_request: ASGIRequest
class ClientHandler(BaseHandler): class ClientHandler(BaseHandler):
enforce_csrf_checks: bool = ... enforce_csrf_checks: bool = ...
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ... def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
def __call__(self, environ: Dict[str, Any]) -> HttpResponseBase: ... def __call__(self, environ: Dict[str, Any]) -> _WSGIResponse: ...
class AsyncClientHandler(BaseHandler): class AsyncClientHandler(BaseHandler):
enforce_csrf_checks: bool = ... enforce_csrf_checks: bool = ...
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ... def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
async def __call__(self, scope: Dict[str, Any]) -> HttpResponseBase: ... async def __call__(self, scope: Dict[str, Any]) -> _ASGIResponse: ...
def encode_multipart(boundary: str, data: Dict[str, Any]) -> bytes: ... def encode_multipart(boundary: str, data: Dict[str, Any]) -> bytes: ...
def encode_file(boundary: str, key: str, file: Any) -> List[bytes]: ... def encode_file(boundary: str, key: str, file: Any) -> List[bytes]: ...
# fake to distinguish WSGIRequest and ASGIRequest
_R = TypeVar("_R", bound=HttpRequest)
class _MonkeyPatchedHttpResponseBase(Generic[_R], HttpResponseBase):
def json(self) -> Any: ...
wsgi_request: _R
class _RequestFactory(Generic[_T]): class _RequestFactory(Generic[_T]):
json_encoder: Type[JSONEncoder] json_encoder: Type[JSONEncoder]
defaults: Dict[str, str] defaults: Dict[str, str]
@@ -102,9 +103,9 @@ class _RequestFactory(Generic[_T]):
**extra: Any **extra: Any
) -> _T: ... ) -> _T: ...
class RequestFactory(_RequestFactory[_T]): ... class RequestFactory(_RequestFactory[WSGIRequest]): ...
class AsyncRequestFactory(_RequestFactory[_T]): class _AsyncRequestFactory(_RequestFactory[_T]):
def request(self, **request: Any) -> _T: ... def request(self, **request: Any) -> _T: ...
def generic( def generic(
self, self,
@@ -116,6 +117,25 @@ class AsyncRequestFactory(_RequestFactory[_T]):
**extra: Any **extra: Any
) -> _T: ... ) -> _T: ...
class AsyncRequestFactory(_AsyncRequestFactory[ASGIRequest]): ...
# fakes to distinguish WSGIRequest and ASGIRequest
class _MonkeyPatchedWSGIResponse(_WSGIResponse):
def json(self) -> Any: ...
request: Dict[str, Any]
client: Client
templates: List[Template]
context: List[Dict[str, Any]]
resolver_match: ResolverMatch
class _MonkeyPatchedASGIResponse(_ASGIResponse):
def json(self) -> Any: ...
request: Dict[str, Any]
client: AsyncClient
templates: List[Template]
context: List[Dict[str, Any]]
resolver_match: ResolverMatch
class ClientMixin: class ClientMixin:
def store_exc_info(self, **kwargs: Any) -> None: ... def store_exc_info(self, **kwargs: Any) -> None: ...
def check_exception(self, response: HttpResponseBase) -> NoReturn: ... def check_exception(self, response: HttpResponseBase) -> NoReturn: ...
@@ -125,7 +145,7 @@ class ClientMixin:
def force_login(self, user: AbstractBaseUser, backend: Optional[str] = ...) -> None: ... def force_login(self, user: AbstractBaseUser, backend: Optional[str] = ...) -> None: ...
def logout(self) -> None: ... def logout(self) -> None: ...
class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequest]]): class Client(ClientMixin, _RequestFactory[_MonkeyPatchedWSGIResponse]):
handler: ClientHandler handler: ClientHandler
raise_request_exception: bool raise_request_exception: bool
exc_info: Optional[Tuple[Type[BaseException], BaseException, TracebackType]] exc_info: Optional[Tuple[Type[BaseException], BaseException, TracebackType]]
@@ -133,19 +153,19 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
) -> None: ... ) -> None: ...
# Silence type warnings, since this class overrides arguments and return types in an unsafe manner. # Silence type warnings, since this class overrides arguments and return types in an unsafe manner.
def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... def request(self, **request: Any) -> _MonkeyPatchedWSGIResponse: ...
def get( # type: ignore def get( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def post( # type: ignore def post( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def head( # type: ignore def head( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def trace( # type: ignore def trace( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def options( # type: ignore def options( # type: ignore
self, self,
path: str, path: str,
@@ -154,18 +174,18 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
follow: bool = ..., follow: bool = ...,
secure: bool = ..., secure: bool = ...,
**extra: Any **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def put( # type: ignore def put( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def patch( # type: ignore def patch( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
def delete( # type: ignore def delete( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ... ) -> _MonkeyPatchedWSGIResponse: ...
class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBase[ASGIRequest]]): class AsyncClient(ClientMixin, _AsyncRequestFactory[Awaitable[_MonkeyPatchedASGIResponse]]):
handler: AsyncClientHandler handler: AsyncClientHandler
raise_request_exception: bool raise_request_exception: bool
exc_info: Any exc_info: Any
@@ -173,4 +193,4 @@ class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBas
def __init__( def __init__(
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
) -> None: ... ) -> None: ...
async def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[ASGIRequest]: ... # type: ignore async def request(self, **request: Any) -> _MonkeyPatchedASGIResponse: ...

View File

@@ -28,7 +28,7 @@ from django.db.models.query import QuerySet, RawQuerySet
from django.forms.fields import EmailField from django.forms.fields import EmailField
from django.http.response import HttpResponse, HttpResponseBase from django.http.response import HttpResponse, HttpResponseBase
from django.template.base import Template from django.template.base import Template
from django.test.client import Client from django.test.client import AsyncClient, Client
from django.test.html import Element from django.test.html import Element
from django.test.utils import CaptureQueriesContext, ContextList from django.test.utils import CaptureQueriesContext, ContextList
from django.utils.functional import classproperty from django.utils.functional import classproperty
@@ -64,8 +64,10 @@ class _DatabaseFailure:
def __call__(self) -> None: ... def __call__(self) -> None: ...
class SimpleTestCase(unittest.TestCase): class SimpleTestCase(unittest.TestCase):
client_class: Any = ... client_class: Type[Client] = ...
client: Client client: Client
async_client_class: Type[AsyncClient] = ...
async_client: AsyncClient
allow_database_queries: bool = ... allow_database_queries: bool = ...
# TODO: str -> Literal['__all__'] # TODO: str -> Literal['__all__']
databases: Union[Set[str], str] = ... databases: Union[Set[str], str] = ...
@@ -142,9 +144,7 @@ class SimpleTestCase(unittest.TestCase):
) -> Any: ... ) -> Any: ...
def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ... def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ... def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertInHTML( def assertInHTML(self, needle: str, haystack: str, count: Optional[int] = ..., msg_prefix: str = ...) -> None: ...
self, needle: str, haystack: SafeString, count: Optional[int] = ..., msg_prefix: str = ...
) -> None: ...
def assertJSONEqual( def assertJSONEqual(
self, self,
raw: str, raw: str,

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Pattern, Tuple, Type, Union, overload from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Pattern, Sequence, Tuple, Type, Union, overload
from django.core.checks.messages import CheckMessage from django.core.checks.messages import CheckMessage
from django.urls.converters import UUIDConverter from django.urls.converters import UUIDConverter
@@ -93,7 +94,7 @@ class URLPattern:
class URLResolver: class URLResolver:
pattern: LocalePrefixPattern = ... pattern: LocalePrefixPattern = ...
urlconf_name: Optional[str] = ... urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]] = ...
callback: None = ... callback: None = ...
default_kwargs: Dict[str, Any] = ... default_kwargs: Dict[str, Any] = ...
namespace: Optional[str] = ... namespace: Optional[str] = ...
@@ -103,7 +104,7 @@ class URLResolver:
def __init__( def __init__(
self, self,
pattern: LocalePrefixPattern, pattern: LocalePrefixPattern,
urlconf_name: Optional[str], urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]],
default_kwargs: Optional[Dict[str, Any]] = ..., default_kwargs: Optional[Dict[str, Any]] = ...,
app_name: Optional[str] = ..., app_name: Optional[str] = ...,
namespace: Optional[str] = ..., namespace: Optional[str] = ...,
@@ -115,7 +116,7 @@ class URLResolver:
@property @property
def app_dict(self) -> Dict[str, List[str]]: ... def app_dict(self) -> Dict[str, List[str]]: ...
@property @property
def urlconf_module(self) -> Optional[List[Tuple[str, Callable]]]: ... def urlconf_module(self) -> Union[ModuleType, None, Sequence[Union[URLPattern, URLResolver]]]: ...
@property @property
def url_patterns(self) -> List[Union[URLPattern, URLResolver]]: ... def url_patterns(self) -> List[Union[URLPattern, URLResolver]]: ...
def resolve(self, path: str) -> ResolverMatch: ... def resolve(self, path: str) -> ResolverMatch: ...

View File

@@ -72,7 +72,7 @@ class classproperty(Generic[_Get]):
fget: Optional[Callable[[_Self], _Get]] = ... fget: Optional[Callable[[_Self], _Get]] = ...
def __init__(self, method: Optional[Callable[[_Self], _Get]] = ...) -> None: ... def __init__(self, method: Optional[Callable[[_Self], _Get]] = ...) -> None: ...
def __get__(self, instance: Optional[_Self], cls: Type[_Self] = ...) -> _Get: ... def __get__(self, instance: Optional[_Self], cls: Type[_Self] = ...) -> _Get: ...
def getter(self, method: Callable[[Type[_Self]], _Get]) -> classproperty[_Get]: ... def getter(self, method: Callable[[_Self], _Get]) -> classproperty[_Get]: ...
class _Getter(Protocol[_Get]): class _Getter(Protocol[_Get]):
"""Type fake to declare some read-only properties (until `property` builtin is generic) """Type fake to declare some read-only properties (until `property` builtin is generic)

View File

@@ -26,7 +26,7 @@ class FormMixin(Generic[_FormT], ContextMixin):
def get_initial(self) -> Dict[str, Any]: ... def get_initial(self) -> Dict[str, Any]: ...
def get_prefix(self) -> Optional[str]: ... def get_prefix(self) -> Optional[str]: ...
def get_form_class(self) -> Type[_FormT]: ... def get_form_class(self) -> Type[_FormT]: ...
def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> BaseForm: ... def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> _FormT: ...
def get_form_kwargs(self) -> Dict[str, Any]: ... def get_form_kwargs(self) -> Dict[str, Any]: ...
def get_success_url(self) -> str: ... def get_success_url(self) -> str: ...
def form_valid(self, form: _FormT) -> HttpResponse: ... def form_valid(self, form: _FormT) -> HttpResponse: ...

View File

@@ -1,5 +1,10 @@
from typing import Any, Protocol import sys
from typing import Any
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
# Used internally by mypy_django_plugin. # Used internally by mypy_django_plugin.
class AnyAttrAllowed(Protocol): class AnyAttrAllowed(Protocol):

View File

@@ -9,6 +9,7 @@ warn_no_return = False
warn_unused_ignores = True warn_unused_ignores = True
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_configs = True warn_unused_configs = True
warn_unreachable = True
plugins = plugins =
mypy_django_plugin.main mypy_django_plugin.main

View File

@@ -35,6 +35,7 @@ except ImportError:
if TYPE_CHECKING: if TYPE_CHECKING:
from django.apps.registry import Apps # noqa: F401 from django.apps.registry import Apps # noqa: F401
from django.conf import LazySettings # noqa: F401 from django.conf import LazySettings # noqa: F401
from django.contrib.contenttypes.fields import GenericForeignKey
@contextmanager @contextmanager
@@ -100,9 +101,6 @@ class DjangoContext:
@cached_property @cached_property
def model_modules(self) -> Dict[str, Set[Type[Model]]]: def model_modules(self) -> Dict[str, Set[Type[Model]]]:
"""All modules that contain Django models.""" """All modules that contain Django models."""
if self.apps_registry is None:
return {}
modules: Dict[str, Set[Type[Model]]] = defaultdict(set) modules: Dict[str, Set[Type[Model]]] = defaultdict(set)
for concrete_model_cls in self.apps_registry.get_models(): for concrete_model_cls in self.apps_registry.get_models():
modules[concrete_model_cls.__module__].add(concrete_model_cls) modules[concrete_model_cls.__module__].add(concrete_model_cls)
@@ -327,7 +325,7 @@ class DjangoContext:
related_model_cls = field.field.model related_model_cls = field.field.model
if isinstance(related_model_cls, str): if isinstance(related_model_cls, str):
if related_model_cls == "self": if related_model_cls == "self": # type: ignore[unreachable]
# same model # same model
related_model_cls = field.model related_model_cls = field.model
elif "." not in related_model_cls: elif "." not in related_model_cls:
@@ -343,7 +341,7 @@ class DjangoContext:
self, field_parts: Iterable[str], model_cls: Type[Model] self, field_parts: Iterable[str], model_cls: Type[Model]
) -> Union[Field, ForeignObjectRel]: ) -> Union[Field, ForeignObjectRel]:
currently_observed_model = model_cls currently_observed_model = model_cls
field: Union[Field, ForeignObjectRel, None] = None field: Union[Field, ForeignObjectRel, GenericForeignKey, None] = None
for field_part in field_parts: for field_part in field_parts:
if field_part == "pk": if field_part == "pk":
field = self.get_primary_key_field(currently_observed_model) field = self.get_primary_key_field(currently_observed_model)
@@ -359,15 +357,16 @@ class DjangoContext:
if isinstance(field, ForeignObjectRel): if isinstance(field, ForeignObjectRel):
currently_observed_model = field.related_model currently_observed_model = field.related_model
assert field is not None # Guaranteed by `query.solve_lookup_type` before.
assert isinstance(field, (Field, ForeignObjectRel))
return field return field
def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Union[Field, ForeignObjectRel]: def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Union[Field, ForeignObjectRel]:
query = Query(model_cls) query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts: if lookup_parts:
raise LookupsAreUnsupported() raise LookupsAreUnsupported()
return self._resolve_field_from_parts(field_parts, model_cls) return self._resolve_field_from_parts(field_parts, model_cls)
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType: def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:

View File

@@ -35,6 +35,7 @@ RELATED_FIELDS_CLASSES = {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME, MANYTOM
MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration" MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration"
OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options" OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options"
HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest" HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest"
QUERYDICT_CLASS_FULLNAME = "django.http.request.QueryDict"
COMBINABLE_EXPRESSION_FULLNAME = "django.db.models.expressions.Combinable" COMBINABLE_EXPRESSION_FULLNAME = "django.db.models.expressions.Combinable"
F_EXPRESSION_FULLNAME = "django.db.models.expressions.F" F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"

View File

@@ -32,6 +32,7 @@ from mypy_django_plugin.transformers.models import (
process_model_class, process_model_class,
set_auth_user_model_boolean_fields, set_auth_user_model_boolean_fields,
) )
from mypy_django_plugin.transformers.request import check_querydict_is_mutable
def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None: def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None:
@@ -187,12 +188,21 @@ class NewSemanalDjangoPlugin(Plugin):
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]: def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition(".") class_fullname, _, method_name = fullname.rpartition(".")
if method_name == "get_form_class": # It is looked up very often, specialcase this method for minor speed up
if method_name == "__init_subclass__":
return None
if class_fullname.endswith("QueryDict"):
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYDICT_CLASS_FULLNAME):
return partial(check_querydict_is_mutable, django_context=self.django_context)
elif method_name == "get_form_class":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form_class return forms.extract_proper_type_for_get_form_class
if method_name == "get_form": elif method_name == "get_form":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form return forms.extract_proper_type_for_get_form
@@ -204,30 +214,30 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context) return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == "values_list": elif method_name == "values_list":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
if method_name == "annotate": elif method_name == "annotate":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes: if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context) return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)
if method_name == "get_field": elif method_name == "get_field":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
if class_fullname in manager_classes and method_name == "create": elif class_fullname in manager_classes and method_name == "create":
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
if class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}: elif class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
return partial( return partial(
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter, mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
django_context=self.django_context, django_context=self.django_context,
) )
if method_name == "from_queryset": elif method_name == "from_queryset":
info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return fail_if_manager_type_created_in_model_body return fail_if_manager_type_created_in_model_body

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from django.db.models.fields import AutoField, Field from django.db.models.fields import AutoField, Field
from django.db.models.fields.related import RelatedField from django.db.models.fields.related import RelatedField
@@ -12,10 +12,13 @@ from mypy.types import TypeOfAny, UnionType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
if TYPE_CHECKING:
from django.contrib.contenttypes.fields import GenericForeignKey
def _get_current_field_from_assignment( def _get_current_field_from_assignment(
ctx: FunctionContext, django_context: DjangoContext ctx: FunctionContext, django_context: DjangoContext
) -> Optional[Union[Field, ForeignObjectRel]]: ) -> Optional[Union[Field, ForeignObjectRel, "GenericForeignKey"]]:
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context): if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context):
return None return None

View File

@@ -263,8 +263,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
semanal_api.defer() semanal_api.defer()
return None return None
original_return_type = method_type.ret_type original_return_type = method_type.ret_type
if original_return_type is None:
continue
# Skip any method that doesn't return _QS # Skip any method that doesn't return _QS
original_return_type = get_proper_type(original_return_type) original_return_type = get_proper_type(original_return_type)

View File

@@ -15,7 +15,7 @@ from mypy.types import TypedDictType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import add_new_class_for_module from mypy_django_plugin.lib.helpers import add_new_class_for_module
from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types from mypy_django_plugin.transformers.fields import get_field_descriptor_types
@@ -475,8 +475,8 @@ def handle_annotated_type(ctx: AnalyzeTypeContext, django_context: DjangoContext
type_arg = ctx.api.analyze_type(args[0]) type_arg = ctx.api.analyze_type(args[0])
api = cast(SemanticAnalyzer, ctx.api.api) # type: ignore api = cast(SemanticAnalyzer, ctx.api.api) # type: ignore
if not isinstance(type_arg, Instance): if not isinstance(type_arg, Instance) or not type_arg.type.has_base(MODEL_CLASS_FULLNAME):
return ctx.api.analyze_type(ctx.type) return type_arg
fields_dict = None fields_dict = None
if len(args) > 1: if len(args) > 1:

View File

@@ -1,7 +1,7 @@
from mypy.plugin import AttributeContext from mypy.plugin import AttributeContext, MethodContext
from mypy.types import Instance from mypy.types import Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import UnionType from mypy.types import UninhabitedType, UnionType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers
@@ -35,3 +35,10 @@ def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_c
return ctx.default_attr_type return ctx.default_attr_type
return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])]) return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])])
def check_querydict_is_mutable(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
ret_type = ctx.default_return_type
if isinstance(ret_type, UninhabitedType):
ctx.api.fail("This QueryDict is immutable.", ctx.context)
return ret_type

View File

@@ -0,0 +1,33 @@
- case: client_methods
main: |
from django.test.client import Client
client = Client()
response = client.get('foo')
reveal_type(response.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
reveal_type(response.client) # N: Revealed type is "django.test.client.Client"
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
response.json()
- case: async_client_methods
main: |
from django.test.client import AsyncClient
async def main():
client = AsyncClient()
response = await client.get('foo')
reveal_type(response.asgi_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest"
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
reveal_type(response.client) # N: Revealed type is "django.test.client.AsyncClient"
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
response.json()
- case: request_factories
main: |
from django.test.client import RequestFactory, AsyncRequestFactory
factory = RequestFactory()
request = factory.get('foo')
reveal_type(request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*"
async_factory = AsyncRequestFactory()
async_request = async_factory.get('foo')
reveal_type(async_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest*"

View File

@@ -9,4 +9,4 @@
reveal_type(resp.status_code) # N: Revealed type is "builtins.int" reveal_type(resp.status_code) # N: Revealed type is "builtins.int"
# Attributes monkey-patched by test Client class: # Attributes monkey-patched by test Client class:
resp.json() resp.json()
reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*" reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"

View File

@@ -0,0 +1,17 @@
# Regression test for #893
- case: annotated_should_not_iterfere
main: |
from dataclasses import dataclass
import sys
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
class IntegerType:
def __init__(self, min_value: int, max_value: int) -> None:
pass
@dataclass(unsafe_hash=True)
class RatingComposite:
max_value: Annotated[int, IntegerType(min_value=1, max_value=10)] = 5

View File

@@ -69,3 +69,41 @@
reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User" reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User"
custom_settings: | custom_settings: |
INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth') INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth')
- case: request_get_post
main: |
from django.http.request import HttpRequest
request = HttpRequest()
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
request.GET['foo'] = 'bar'
def mk_request() -> HttpRequest:
return HttpRequest()
req = mk_request()
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
req.GET.setdefault('foo', 'bar') # E: This QueryDict is immutable.
x = 1 # E: Statement is unreachable
# Will work after merging https://github.com/python/mypy/pull/12572
- case: request_get_post_unreachable
main: |
from django.http.request import HttpRequest
request = HttpRequest()
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
request.GET['foo'] = 'bar'
def mk_request() -> HttpRequest:
return HttpRequest()
req = mk_request()
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
req.GET['foo'] = 'bar' # E: This QueryDict is immutable.
x = 1 # E: Statement is unreachable
expect_fail: true

View File

@@ -61,3 +61,34 @@
class Article(models.Model): class Article(models.Model):
pass pass
- case: generic_form_views_different_form_classes
main: |
from django.views.generic.edit import CreateView
from django import forms
from myapp.models import Article
class ArticleModelForm(forms.ModelForm[Article]):
class Meta:
model = Article
class SubArticleModelForm(ArticleModelForm):
pass
class AnotherArticleModelForm(forms.ModelForm[Article]):
class Meta:
model = Article
class MyCreateView(CreateView[Article, ArticleModelForm]):
def some(self) -> None:
reveal_type(self.get_form()) # N: Revealed type is "main.ArticleModelForm*"
reveal_type(self.get_form(SubArticleModelForm)) # N: Revealed type is "main.SubArticleModelForm"
reveal_type(self.get_form(AnotherArticleModelForm)) # N: Revealed type is "main.AnotherArticleModelForm" # E: Argument 1 to "get_form" of "FormMixin" has incompatible type "Type[AnotherArticleModelForm]"; expected "Optional[Type[ArticleModelForm]]"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Article(models.Model):
pass