Recover after #909 (#925)

* Fix stubs related to `(Async)RequestFactory` and `(Async)Client`

* Revert incorrect removal.

* Allow set as `unique_together`, use shared type alias.

* Revert `Q.__init__` to use only `*args, **kwargs` to remove false-positive with `Q(**{...})`

* Add abstract methods to `HttpResponseBase` to create common interface.

* Remove monkey-patched attributes from `HttpResponseBase` subclasses.

* Add QueryDict mutability checks (+ plugin support)

* Fix lint

* Return back GenericForeignKey to `Options.get_fields`

* Minor fixup

* Make plugin code typecheck with `--warn-unreachable`, minor performance increase.

* Better types for `{unique, index}_together` and Options.

* Fix odd type of `URLResolver.urlconf_name` which isn't a str actually.

* Better types for field migration operations.

* Revert form.files to `MultiValueDict[str, UploadedFile]`

* Compatibility fix (#916)

* Do not assume that `Annotated` is always related to django-stubs (fixes #893)

* Restrict `FormView.get_form` return type to `_FormT` (class type argument). Now it is resolved to `form_class` argument if present, but also errors if it is not subclass of _FormT

* Fix CI (make test runnable on 3.8)

* Fix CI (make test runnable on 3.8 _again_)
This commit is contained in:
sterliakov
2022-04-28 13:01:37 +03:00
committed by GitHub
parent 16499a22ab
commit 6226381484
29 changed files with 380 additions and 138 deletions

View File

@@ -15,7 +15,7 @@ from typing import (
)
from django.core.handlers import base as base
from django.http import HttpRequest, QueryDict
from django.http.request import HttpRequest, _ImmutableQueryDict
from django.http.response import HttpResponseBase
from django.urls.resolvers import ResolverMatch, URLResolver
from django.utils.datastructures import MultiValueDict
@@ -34,8 +34,8 @@ class ASGIRequest(HttpRequest):
META: Dict[str, Any] = ...
def __init__(self, scope: Mapping[str, Any], body_file: IO[bytes]) -> None: ...
@property
def GET(self) -> QueryDict: ... # type: ignore
POST: QueryDict = ...
def GET(self) -> _ImmutableQueryDict: ... # type: ignore
POST: _ImmutableQueryDict = ...
FILES: MultiValueDict = ...
@property
def COOKIES(self) -> Dict[str, str]: ... # type: ignore

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Optional
from django.db.models.fields import Field
@@ -23,15 +23,15 @@ class AddField(FieldOperation):
class RemoveField(FieldOperation): ...
class AlterField(FieldOperation):
field: Any = ...
preserve_default: Any = ...
field: Field = ...
preserve_default: bool = ...
def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = ...) -> None: ...
class RenameField(FieldOperation):
old_name: Any = ...
new_name: Any = ...
old_name: str = ...
new_name: str = ...
def __init__(self, model_name: str, old_name: str, new_name: str) -> None: ...
@property
def old_name_lower(self): ...
def old_name_lower(self) -> str: ...
@property
def new_name_lower(self) -> str: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.operations.base import Operation
@@ -7,9 +7,7 @@ from django.db.models.constraints import BaseConstraint
from django.db.models.fields import Field
from django.db.models.indexes import Index
from django.db.models.manager import Manager
from django.utils.datastructures import _ListOrTuple
_T = TypeVar("_T")
from django.db.models.options import _OptionTogetherT
class ModelOperation(Operation):
name: str = ...
@@ -53,10 +51,10 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
def __init__(
self,
name: str,
option_value: _ListOrTuple[Tuple[str, str]],
option_value: Optional[_OptionTogetherT],
) -> None: ...
@property
def option_value(self) -> Set[Tuple[str, str]]: ...
def option_value(self) -> Optional[Set[Tuple[str, ...]]]: ...
def deconstruct(self) -> Tuple[str, Sequence[Any], Dict[str, Any]]: ...
def state_forwards(self, app_label: str, state: Any) -> None: ...
def database_forwards(
@@ -72,26 +70,26 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
class AlterUniqueTogether(AlterTogetherOptionOperation):
option_name: str = ...
unique_together: _ListOrTuple[Tuple[str, str]] = ...
def __init__(self, name: str, unique_together: _ListOrTuple[Tuple[str, str]]) -> None: ...
unique_together: Optional[Set[Tuple[str, ...]]] = ...
def __init__(self, name: str, unique_together: Optional[_OptionTogetherT]) -> None: ...
class AlterIndexTogether(AlterTogetherOptionOperation):
option_name: str = ...
index_together: _ListOrTuple[Tuple[str, str]] = ...
def __init__(self, name: str, index_together: _ListOrTuple[Tuple[str, str]]) -> None: ...
index_together: Optional[Set[Tuple[str, ...]]] = ...
def __init__(self, name: str, index_together: Optional[_OptionTogetherT]) -> None: ...
class AlterOrderWithRespectTo(ModelOptionOperation):
order_with_respect_to: str = ...
def __init__(self, name: str, order_with_respect_to: str) -> None: ...
class AlterModelOptions(ModelOptionOperation):
ALTER_OPTION_KEYS: Any = ...
options: Dict[str, str] = ...
ALTER_OPTION_KEYS: List[str] = ...
options: Dict[str, Any] = ...
def __init__(self, name: str, options: Dict[str, Any]) -> None: ...
class AlterModelManagers(ModelOptionOperation):
managers: Any = ...
def __init__(self, name: Any, managers: Any) -> None: ...
managers: Sequence[Manager] = ...
def __init__(self, name: str, managers: Sequence[Manager]) -> None: ...
class IndexOperation(Operation):
option_name: str = ...

View File

@@ -1,13 +1,9 @@
import sys
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Union
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.state import StateApps
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
from django.utils.datastructures import _ListOrTuple
from .base import Operation
@@ -25,14 +21,14 @@ class SeparateDatabaseAndState(Operation):
class RunSQL(Operation):
noop: Literal[""] = ...
sql: Union[str, List[str], Tuple[str, ...]] = ...
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ...
sql: Union[str, _ListOrTuple[str]] = ...
reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...
state_operations: Sequence[Operation] = ...
hints: Mapping[str, Any] = ...
def __init__(
self,
sql: Union[str, List[str], Tuple[str, ...]],
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ...,
sql: Union[str, _ListOrTuple[str]],
reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...,
state_operations: Sequence[Operation] = ...,
hints: Optional[Mapping[str, Any]] = ...,
elidable: bool = ...,
@@ -44,13 +40,13 @@ class _CodeCallable(Protocol):
class RunPython(Operation):
code: _CodeCallable = ...
reverse_code: Optional[_CodeCallable] = ...
hints: Optional[Dict[str, Any]] = ...
hints: Mapping[str, Any] = ...
def __init__(
self,
code: _CodeCallable,
reverse_code: Optional[_CodeCallable] = ...,
atomic: Optional[bool] = ...,
hints: Optional[Dict[str, Any]] = ...,
hints: Optional[Mapping[str, Any]] = ...,
elidable: bool = ...,
) -> None: ...
@staticmethod

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
import sys
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, overload
from django.apps.config import AppConfig
from django.apps.registry import Apps
@@ -11,16 +12,26 @@ from django.db.models.fields.related import ManyToManyField, OneToOneField
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.manager import Manager
from django.db.models.query_utils import PathInfo
from django.utils.datastructures import ImmutableList
from django.utils.datastructures import ImmutableList, _ListOrTuple
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
PROXY_PARENTS: object
EMPTY_RELATION_TREE: Any
IMMUTABLE_WARNING: str
DEFAULT_NAMES: Tuple[str, ...]
def normalize_together(
option_together: Union[List[Tuple[str, str]], Tuple[Tuple[str, str], ...], Tuple[()], Tuple[str, str]]
) -> Tuple[Tuple[str, str], ...]: ...
_OptionTogetherT = Union[_ListOrTuple[Union[_ListOrTuple[str], str]], Set[Tuple[str, ...]]]
@overload
def normalize_together(option_together: _ListOrTuple[Union[_ListOrTuple[str], str]]) -> Tuple[Tuple[str, ...], ...]: ...
# Any other value will be returned unchanged, but probably only set is semantically allowed
@overload
def normalize_together(option_together: Set[Tuple[str, ...]]) -> Set[Tuple[str, ...]]: ...
_T = TypeVar("_T")
@@ -45,18 +56,18 @@ class Options(Generic[_M]):
db_table: str = ...
ordering: Optional[Sequence[str]] = ...
indexes: List[Any] = ...
unique_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ...
index_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ...
unique_together: Sequence[Tuple[str]] = ... # Are always normalized
index_together: Sequence[Tuple[str]] = ... # Are always normalized
select_on_save: bool = ...
default_permissions: Sequence[str] = ...
permissions: List[Any] = ...
object_name: Optional[str] = ...
app_label: str = ...
get_latest_by: Optional[Sequence[str]] = ...
order_with_respect_to: Optional[Any] = ...
order_with_respect_to: Optional[str] = ...
db_tablespace: str = ...
required_db_features: List[Any] = ...
required_db_vendor: Any = ...
required_db_features: List[str] = ...
required_db_vendor: Optional[Literal["sqlite", "postgresql", "mysql", "oracle"]] = ...
meta: Optional[type] = ...
pk: Optional[Field] = ...
auto_field: Optional[AutoField] = ...
@@ -105,7 +116,7 @@ class Options(Generic[_M]):
def default_manager(self) -> Optional[Manager]: ...
@property
def fields(self) -> ImmutableList[Field[Any, Any]]: ...
def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel]: ...
def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel, GenericForeignKey]: ...
def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ...
def get_parent_list(self) -> List[Type[Model]]: ...
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
@@ -113,7 +124,7 @@ class Options(Generic[_M]):
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
def get_fields(
self, include_parents: bool = ..., include_hidden: bool = ...
) -> List[Union[Field[Any, Any], ForeignObjectRel]]: ...
) -> List[Union[Field[Any, Any], ForeignObjectRel, GenericForeignKey]]: ...
@property
def total_unique_constraints(self) -> List[UniqueConstraint]: ...
@property

View File

@@ -45,7 +45,9 @@ class Q(tree.Node):
AND: str = ...
OR: str = ...
conditional: bool = ...
def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
# Fake signature, the real is
# def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ...
def __or__(self, other: Q) -> Q: ...
def __and__(self, other: Q) -> Q: ...
def __invert__(self) -> Q: ...

View File

@@ -41,6 +41,8 @@ class ConnectionRouter:
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
@property
def routers(self) -> List[Any]: ...
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...

View File

@@ -3,11 +3,12 @@ from datetime import datetime
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from django.core.exceptions import ValidationError
from django.core.files import File
from django.core.files.uploadedfile import UploadedFile
from django.utils.datastructures import MultiValueDict
from django.utils.safestring import SafeString
_DataT = Mapping[str, Any]
_FilesT = Mapping[str, Iterable[File]]
_FilesT = MultiValueDict[str, UploadedFile]
def pretty_name(name: str) -> str: ...
def flatatt(attrs: Dict[str, Any]) -> SafeString: ...

View File

@@ -1,3 +1,4 @@
import sys
from io import BytesIO
from typing import (
Any,
@@ -6,6 +7,7 @@ from typing import (
Iterable,
List,
Mapping,
NoReturn,
Optional,
Pattern,
Set,
@@ -24,6 +26,11 @@ from django.core.files import uploadedfile, uploadhandler
from django.urls import ResolverMatch
from django.utils.datastructures import CaseInsensitiveMapping, ImmutableList, MultiValueDict
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
RAISE_ERROR: object = ...
host_validation_re: Pattern[str] = ...
@@ -40,8 +47,8 @@ class HttpHeaders(CaseInsensitiveMapping[str]):
def parse_header_name(cls, header: str) -> Optional[str]: ...
class HttpRequest(BytesIO):
GET: QueryDict = ...
POST: QueryDict = ...
GET: _ImmutableQueryDict = ...
POST: _ImmutableQueryDict = ...
COOKIES: Dict[str, str] = ...
META: Dict[str, Any] = ...
FILES: MultiValueDict[str, uploadedfile.UploadedFile] = ...
@@ -55,7 +62,15 @@ class HttpRequest(BytesIO):
site: Site
session: SessionBase
_stream: BinaryIO
def __init__(self) -> None: ...
# The magic. If we instantiate HttpRequest directly somewhere, it has
# mutable GET and POST. However, both ASGIRequest and WSGIRequest have immutable,
# so when we use HttpRequest to refer to any of them we want exactly this.
# Case when some function creates *exactly* HttpRequest (not subclass)
# remain uncovered, however it's probably the best solution we can afford.
def __new__(cls) -> _MutableHttpRequest: ...
# When both __init__ and __new__ are present, mypy will prefer __init__
# (see comments in mypy.checkmember.type_object_type)
# def __init__(self) -> None: ...
def get_host(self) -> str: ...
def get_port(self) -> str: ...
def get_full_path(self, force_append_slash: bool = ...) -> str: ...
@@ -90,17 +105,43 @@ class HttpRequest(BytesIO):
def _load_post_and_files(self) -> None: ...
def accepts(self, media_type: str) -> bool: ...
class _MutableHttpRequest(HttpRequest):
GET: QueryDict = ... # type: ignore[assignment]
POST: QueryDict = ... # type: ignore[assignment]
_Q = TypeVar("_Q", bound="QueryDict")
_Z = TypeVar("_Z")
class QueryDict(MultiValueDict[str, str]):
_mutable: bool = ...
# We can make it mutable only by specifying `mutable=True`.
# It can be done a) with kwarg and b) with pos. arg. `overload` has
# some problems with args/kwargs + Literal, so two signatures are required.
# ('querystring', True, [...])
@overload
def __init__(
self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ...
self: QueryDict,
query_string: Optional[Union[str, bytes]],
mutable: Literal[True],
encoding: Optional[str] = ...,
) -> None: ...
# ([querystring='string',] mutable=True, [...])
@overload
def __init__(
self: QueryDict,
*,
mutable: Literal[True],
query_string: Optional[Union[str, bytes]] = ...,
encoding: Optional[str] = ...,
) -> None: ...
# Otherwise it's immutable
@overload
def __init__( # type: ignore[misc]
self: _ImmutableQueryDict,
query_string: Optional[Union[str, bytes]] = ...,
mutable: bool = ...,
encoding: Optional[str] = ...,
) -> None: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
def urlencode(self, safe: Optional[str] = ...) -> str: ...
@classmethod
def fromkeys( # type: ignore
cls: Type[_Q],
@@ -114,7 +155,46 @@ class QueryDict(MultiValueDict[str, str]):
@encoding.setter
def encoding(self, value: str) -> None: ...
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
def __delitem__(self, key: Union[str, bytes]) -> None: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
@overload
def pop(self, key: Union[str, bytes], /) -> str: ...
@overload
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> Union[str, _Z]: ...
def popitem(self) -> Tuple[str, str]: ...
def clear(self) -> None: ...
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> str: ...
def copy(self) -> QueryDict: ...
def urlencode(self, safe: Optional[str] = ...) -> str: ...
class _ImmutableQueryDict(QueryDict):
_mutable: Literal[False]
# def __init__(
# self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ...
# ) -> None: ...
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
def __delitem__(self, key: Union[str, bytes]) -> NoReturn: ...
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> NoReturn: ...
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> NoReturn: ...
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
@overload
def pop(self, key: Union[str, bytes], /) -> NoReturn: ...
@overload
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> NoReturn: ...
def popitem(self) -> NoReturn: ...
def clear(self) -> NoReturn: ...
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> NoReturn: ...
def copy(self) -> QueryDict: ... # type: ignore[override]
def urlencode(self, safe: Optional[str] = ...) -> str: ...
# Fakes for convenience (for `request.GET` and `request.POST`). If dict
# was created by Django, there is no chance to hit `List[object]` (empty list)
# edge case.
def __getitem__(self, key: str) -> str: ...
def dict(self) -> Dict[str, str]: ... # type: ignore[override]
class MediaType:
def __init__(self, media_type_raw_line: str) -> None: ...

View File

@@ -1,13 +1,9 @@
import datetime
from io import BytesIO
from json import JSONEncoder
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload, type_check_only
from django.core.handlers.wsgi import WSGIRequest
from django.http.cookie import SimpleCookie
from django.template import Context, Template
from django.test.client import Client
from django.urls import ResolverMatch
from django.utils.datastructures import CaseInsensitiveMapping, _PropertyDescriptor
class BadHeaderError(ValueError): ...
@@ -82,6 +78,12 @@ class HttpResponseBase:
def writable(self) -> bool: ...
def writelines(self, lines: Iterable[object]) -> None: ...
# Fake methods that are implemented by all subclasses
@type_check_only
def __iter__(self) -> Iterator[bytes]: ...
@type_check_only
def getvalue(self) -> bytes: ...
class HttpResponse(HttpResponseBase, Iterable[bytes]):
content = _PropertyDescriptor[object, bytes]()
csrf_cookie_set: bool
@@ -94,14 +96,6 @@ class HttpResponse(HttpResponseBase, Iterable[bytes]):
def serialize(self) -> bytes: ...
__bytes__ = serialize
def __iter__(self) -> Iterator[bytes]: ...
# Attributes assigned by monkey-patching in test client ClientHandler.__call__()
wsgi_request: WSGIRequest
# Attributes assigned by monkey-patching in test client Client.request()
client: Client
request: Dict[str, Any]
templates: List[Template]
context: Context
resolver_match: ResolverMatch
def getvalue(self) -> bytes: ...
class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
@@ -111,13 +105,7 @@ class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
def getvalue(self) -> bytes: ...
class FileResponse(StreamingHttpResponse):
client: Client
context: None
file_to_stream: Optional[BytesIO]
request: Dict[str, str]
resolver_match: ResolverMatch
templates: List[Any]
wsgi_request: WSGIRequest
block_size: int = ...
as_attachment: bool = ...
filename: str = ...

View File

@@ -3,6 +3,7 @@ from json import JSONEncoder
from types import TracebackType
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
@@ -26,6 +27,8 @@ from django.core.handlers.wsgi import WSGIRequest
from django.http.cookie import SimpleCookie
from django.http.request import HttpRequest
from django.http.response import HttpResponseBase
from django.template.base import Template
from django.urls import ResolverMatch
BOUNDARY: str = ...
MULTIPART_CONTENT: str = ...
@@ -49,27 +52,25 @@ _T = TypeVar("_T")
def closing_iterator_wrapper(iterable: Iterable[_T], close: Callable[[], Any]) -> Iterator[_T]: ...
def conditional_content_removal(request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ...
class _WSGIResponse(HttpResponseBase):
wsgi_request: WSGIRequest
class _ASGIResponse(HttpResponseBase):
asgi_request: ASGIRequest
class ClientHandler(BaseHandler):
enforce_csrf_checks: bool = ...
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
def __call__(self, environ: Dict[str, Any]) -> HttpResponseBase: ...
def __call__(self, environ: Dict[str, Any]) -> _WSGIResponse: ...
class AsyncClientHandler(BaseHandler):
enforce_csrf_checks: bool = ...
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
async def __call__(self, scope: Dict[str, Any]) -> HttpResponseBase: ...
async def __call__(self, scope: Dict[str, Any]) -> _ASGIResponse: ...
def encode_multipart(boundary: str, data: Dict[str, Any]) -> bytes: ...
def encode_file(boundary: str, key: str, file: Any) -> List[bytes]: ...
# fake to distinguish WSGIRequest and ASGIRequest
_R = TypeVar("_R", bound=HttpRequest)
class _MonkeyPatchedHttpResponseBase(Generic[_R], HttpResponseBase):
def json(self) -> Any: ...
wsgi_request: _R
class _RequestFactory(Generic[_T]):
json_encoder: Type[JSONEncoder]
defaults: Dict[str, str]
@@ -102,9 +103,9 @@ class _RequestFactory(Generic[_T]):
**extra: Any
) -> _T: ...
class RequestFactory(_RequestFactory[_T]): ...
class RequestFactory(_RequestFactory[WSGIRequest]): ...
class AsyncRequestFactory(_RequestFactory[_T]):
class _AsyncRequestFactory(_RequestFactory[_T]):
def request(self, **request: Any) -> _T: ...
def generic(
self,
@@ -116,6 +117,25 @@ class AsyncRequestFactory(_RequestFactory[_T]):
**extra: Any
) -> _T: ...
class AsyncRequestFactory(_AsyncRequestFactory[ASGIRequest]): ...
# fakes to distinguish WSGIRequest and ASGIRequest
class _MonkeyPatchedWSGIResponse(_WSGIResponse):
def json(self) -> Any: ...
request: Dict[str, Any]
client: Client
templates: List[Template]
context: List[Dict[str, Any]]
resolver_match: ResolverMatch
class _MonkeyPatchedASGIResponse(_ASGIResponse):
def json(self) -> Any: ...
request: Dict[str, Any]
client: AsyncClient
templates: List[Template]
context: List[Dict[str, Any]]
resolver_match: ResolverMatch
class ClientMixin:
def store_exc_info(self, **kwargs: Any) -> None: ...
def check_exception(self, response: HttpResponseBase) -> NoReturn: ...
@@ -125,7 +145,7 @@ class ClientMixin:
def force_login(self, user: AbstractBaseUser, backend: Optional[str] = ...) -> None: ...
def logout(self) -> None: ...
class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequest]]):
class Client(ClientMixin, _RequestFactory[_MonkeyPatchedWSGIResponse]):
handler: ClientHandler
raise_request_exception: bool
exc_info: Optional[Tuple[Type[BaseException], BaseException, TracebackType]]
@@ -133,19 +153,19 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
) -> None: ...
# Silence type warnings, since this class overrides arguments and return types in an unsafe manner.
def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
def request(self, **request: Any) -> _MonkeyPatchedWSGIResponse: ...
def get( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def post( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def head( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def trace( # type: ignore
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def options( # type: ignore
self,
path: str,
@@ -154,18 +174,18 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
follow: bool = ...,
secure: bool = ...,
**extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def put( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def patch( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
def delete( # type: ignore
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
) -> _MonkeyPatchedWSGIResponse: ...
class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBase[ASGIRequest]]):
class AsyncClient(ClientMixin, _AsyncRequestFactory[Awaitable[_MonkeyPatchedASGIResponse]]):
handler: AsyncClientHandler
raise_request_exception: bool
exc_info: Any
@@ -173,4 +193,4 @@ class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBas
def __init__(
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
) -> None: ...
async def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[ASGIRequest]: ... # type: ignore
async def request(self, **request: Any) -> _MonkeyPatchedASGIResponse: ...

View File

@@ -28,7 +28,7 @@ from django.db.models.query import QuerySet, RawQuerySet
from django.forms.fields import EmailField
from django.http.response import HttpResponse, HttpResponseBase
from django.template.base import Template
from django.test.client import Client
from django.test.client import AsyncClient, Client
from django.test.html import Element
from django.test.utils import CaptureQueriesContext, ContextList
from django.utils.functional import classproperty
@@ -64,8 +64,10 @@ class _DatabaseFailure:
def __call__(self) -> None: ...
class SimpleTestCase(unittest.TestCase):
client_class: Any = ...
client_class: Type[Client] = ...
client: Client
async_client_class: Type[AsyncClient] = ...
async_client: AsyncClient
allow_database_queries: bool = ...
# TODO: str -> Literal['__all__']
databases: Union[Set[str], str] = ...
@@ -142,9 +144,7 @@ class SimpleTestCase(unittest.TestCase):
) -> Any: ...
def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertInHTML(
self, needle: str, haystack: SafeString, count: Optional[int] = ..., msg_prefix: str = ...
) -> None: ...
def assertInHTML(self, needle: str, haystack: str, count: Optional[int] = ..., msg_prefix: str = ...) -> None: ...
def assertJSONEqual(
self,
raw: str,

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Pattern, Tuple, Type, Union, overload
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Pattern, Sequence, Tuple, Type, Union, overload
from django.core.checks.messages import CheckMessage
from django.urls.converters import UUIDConverter
@@ -93,7 +94,7 @@ class URLPattern:
class URLResolver:
pattern: LocalePrefixPattern = ...
urlconf_name: Optional[str] = ...
urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]] = ...
callback: None = ...
default_kwargs: Dict[str, Any] = ...
namespace: Optional[str] = ...
@@ -103,7 +104,7 @@ class URLResolver:
def __init__(
self,
pattern: LocalePrefixPattern,
urlconf_name: Optional[str],
urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]],
default_kwargs: Optional[Dict[str, Any]] = ...,
app_name: Optional[str] = ...,
namespace: Optional[str] = ...,
@@ -115,7 +116,7 @@ class URLResolver:
@property
def app_dict(self) -> Dict[str, List[str]]: ...
@property
def urlconf_module(self) -> Optional[List[Tuple[str, Callable]]]: ...
def urlconf_module(self) -> Union[ModuleType, None, Sequence[Union[URLPattern, URLResolver]]]: ...
@property
def url_patterns(self) -> List[Union[URLPattern, URLResolver]]: ...
def resolve(self, path: str) -> ResolverMatch: ...

View File

@@ -72,7 +72,7 @@ class classproperty(Generic[_Get]):
fget: Optional[Callable[[_Self], _Get]] = ...
def __init__(self, method: Optional[Callable[[_Self], _Get]] = ...) -> None: ...
def __get__(self, instance: Optional[_Self], cls: Type[_Self] = ...) -> _Get: ...
def getter(self, method: Callable[[Type[_Self]], _Get]) -> classproperty[_Get]: ...
def getter(self, method: Callable[[_Self], _Get]) -> classproperty[_Get]: ...
class _Getter(Protocol[_Get]):
"""Type fake to declare some read-only properties (until `property` builtin is generic)

View File

@@ -26,7 +26,7 @@ class FormMixin(Generic[_FormT], ContextMixin):
def get_initial(self) -> Dict[str, Any]: ...
def get_prefix(self) -> Optional[str]: ...
def get_form_class(self) -> Type[_FormT]: ...
def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> BaseForm: ...
def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> _FormT: ...
def get_form_kwargs(self) -> Dict[str, Any]: ...
def get_success_url(self) -> str: ...
def form_valid(self, form: _FormT) -> HttpResponse: ...

View File

@@ -1,5 +1,10 @@
from typing import Any, Protocol
import sys
from typing import Any
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
# Used internally by mypy_django_plugin.
class AnyAttrAllowed(Protocol):

View File

@@ -9,6 +9,7 @@ warn_no_return = False
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unreachable = True
plugins =
mypy_django_plugin.main

View File

@@ -35,6 +35,7 @@ except ImportError:
if TYPE_CHECKING:
from django.apps.registry import Apps # noqa: F401
from django.conf import LazySettings # noqa: F401
from django.contrib.contenttypes.fields import GenericForeignKey
@contextmanager
@@ -100,9 +101,6 @@ class DjangoContext:
@cached_property
def model_modules(self) -> Dict[str, Set[Type[Model]]]:
"""All modules that contain Django models."""
if self.apps_registry is None:
return {}
modules: Dict[str, Set[Type[Model]]] = defaultdict(set)
for concrete_model_cls in self.apps_registry.get_models():
modules[concrete_model_cls.__module__].add(concrete_model_cls)
@@ -327,7 +325,7 @@ class DjangoContext:
related_model_cls = field.field.model
if isinstance(related_model_cls, str):
if related_model_cls == "self":
if related_model_cls == "self": # type: ignore[unreachable]
# same model
related_model_cls = field.model
elif "." not in related_model_cls:
@@ -343,7 +341,7 @@ class DjangoContext:
self, field_parts: Iterable[str], model_cls: Type[Model]
) -> Union[Field, ForeignObjectRel]:
currently_observed_model = model_cls
field: Union[Field, ForeignObjectRel, None] = None
field: Union[Field, ForeignObjectRel, GenericForeignKey, None] = None
for field_part in field_parts:
if field_part == "pk":
field = self.get_primary_key_field(currently_observed_model)
@@ -359,15 +357,16 @@ class DjangoContext:
if isinstance(field, ForeignObjectRel):
currently_observed_model = field.related_model
assert field is not None
# Guaranteed by `query.solve_lookup_type` before.
assert isinstance(field, (Field, ForeignObjectRel))
return field
def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Union[Field, ForeignObjectRel]:
query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts:
raise LookupsAreUnsupported()
return self._resolve_field_from_parts(field_parts, model_cls)
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:

View File

@@ -35,6 +35,7 @@ RELATED_FIELDS_CLASSES = {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME, MANYTOM
MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration"
OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options"
HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest"
QUERYDICT_CLASS_FULLNAME = "django.http.request.QueryDict"
COMBINABLE_EXPRESSION_FULLNAME = "django.db.models.expressions.Combinable"
F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"

View File

@@ -32,6 +32,7 @@ from mypy_django_plugin.transformers.models import (
process_model_class,
set_auth_user_model_boolean_fields,
)
from mypy_django_plugin.transformers.request import check_querydict_is_mutable
def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None:
@@ -187,12 +188,21 @@ class NewSemanalDjangoPlugin(Plugin):
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition(".")
if method_name == "get_form_class":
# It is looked up very often, specialcase this method for minor speed up
if method_name == "__init_subclass__":
return None
if class_fullname.endswith("QueryDict"):
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYDICT_CLASS_FULLNAME):
return partial(check_querydict_is_mutable, django_context=self.django_context)
elif method_name == "get_form_class":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form_class
if method_name == "get_form":
elif method_name == "get_form":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form
@@ -204,30 +214,30 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == "values_list":
elif method_name == "values_list":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
if method_name == "annotate":
elif method_name == "annotate":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)
if method_name == "get_field":
elif method_name == "get_field":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
if class_fullname in manager_classes and method_name == "create":
elif class_fullname in manager_classes and method_name == "create":
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
if class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
elif class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
return partial(
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
django_context=self.django_context,
)
if method_name == "from_queryset":
elif method_name == "from_queryset":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return fail_if_manager_type_created_in_model_body

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from django.db.models.fields import AutoField, Field
from django.db.models.fields.related import RelatedField
@@ -12,10 +12,13 @@ from mypy.types import TypeOfAny, UnionType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
if TYPE_CHECKING:
from django.contrib.contenttypes.fields import GenericForeignKey
def _get_current_field_from_assignment(
ctx: FunctionContext, django_context: DjangoContext
) -> Optional[Union[Field, ForeignObjectRel]]:
) -> Optional[Union[Field, ForeignObjectRel, "GenericForeignKey"]]:
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context):
return None

View File

@@ -263,8 +263,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
semanal_api.defer()
return None
original_return_type = method_type.ret_type
if original_return_type is None:
continue
# Skip any method that doesn't return _QS
original_return_type = get_proper_type(original_return_type)

View File

@@ -15,7 +15,7 @@ from mypy.types import TypedDictType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import add_new_class_for_module
from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
@@ -475,8 +475,8 @@ def handle_annotated_type(ctx: AnalyzeTypeContext, django_context: DjangoContext
type_arg = ctx.api.analyze_type(args[0])
api = cast(SemanticAnalyzer, ctx.api.api) # type: ignore
if not isinstance(type_arg, Instance):
return ctx.api.analyze_type(ctx.type)
if not isinstance(type_arg, Instance) or not type_arg.type.has_base(MODEL_CLASS_FULLNAME):
return type_arg
fields_dict = None
if len(args) > 1:

View File

@@ -1,7 +1,7 @@
from mypy.plugin import AttributeContext
from mypy.plugin import AttributeContext, MethodContext
from mypy.types import Instance
from mypy.types import Type as MypyType
from mypy.types import UnionType
from mypy.types import UninhabitedType, UnionType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers
@@ -35,3 +35,10 @@ def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_c
return ctx.default_attr_type
return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])])
def check_querydict_is_mutable(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
ret_type = ctx.default_return_type
if isinstance(ret_type, UninhabitedType):
ctx.api.fail("This QueryDict is immutable.", ctx.context)
return ret_type

View File

@@ -0,0 +1,33 @@
- case: client_methods
main: |
from django.test.client import Client
client = Client()
response = client.get('foo')
reveal_type(response.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
reveal_type(response.client) # N: Revealed type is "django.test.client.Client"
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
response.json()
- case: async_client_methods
main: |
from django.test.client import AsyncClient
async def main():
client = AsyncClient()
response = await client.get('foo')
reveal_type(response.asgi_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest"
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
reveal_type(response.client) # N: Revealed type is "django.test.client.AsyncClient"
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
response.json()
- case: request_factories
main: |
from django.test.client import RequestFactory, AsyncRequestFactory
factory = RequestFactory()
request = factory.get('foo')
reveal_type(request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*"
async_factory = AsyncRequestFactory()
async_request = async_factory.get('foo')
reveal_type(async_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest*"

View File

@@ -9,4 +9,4 @@
reveal_type(resp.status_code) # N: Revealed type is "builtins.int"
# Attributes monkey-patched by test Client class:
resp.json()
reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*"
reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"

View File

@@ -0,0 +1,17 @@
# Regression test for #893
- case: annotated_should_not_iterfere
main: |
from dataclasses import dataclass
import sys
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
class IntegerType:
def __init__(self, min_value: int, max_value: int) -> None:
pass
@dataclass(unsafe_hash=True)
class RatingComposite:
max_value: Annotated[int, IntegerType(min_value=1, max_value=10)] = 5

View File

@@ -69,3 +69,41 @@
reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User"
custom_settings: |
INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth')
- case: request_get_post
main: |
from django.http.request import HttpRequest
request = HttpRequest()
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
request.GET['foo'] = 'bar'
def mk_request() -> HttpRequest:
return HttpRequest()
req = mk_request()
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
req.GET.setdefault('foo', 'bar') # E: This QueryDict is immutable.
x = 1 # E: Statement is unreachable
# Will work after merging https://github.com/python/mypy/pull/12572
- case: request_get_post_unreachable
main: |
from django.http.request import HttpRequest
request = HttpRequest()
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
request.GET['foo'] = 'bar'
def mk_request() -> HttpRequest:
return HttpRequest()
req = mk_request()
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
req.GET['foo'] = 'bar' # E: This QueryDict is immutable.
x = 1 # E: Statement is unreachable
expect_fail: true

View File

@@ -61,3 +61,34 @@
class Article(models.Model):
pass
- case: generic_form_views_different_form_classes
main: |
from django.views.generic.edit import CreateView
from django import forms
from myapp.models import Article
class ArticleModelForm(forms.ModelForm[Article]):
class Meta:
model = Article
class SubArticleModelForm(ArticleModelForm):
pass
class AnotherArticleModelForm(forms.ModelForm[Article]):
class Meta:
model = Article
class MyCreateView(CreateView[Article, ArticleModelForm]):
def some(self) -> None:
reveal_type(self.get_form()) # N: Revealed type is "main.ArticleModelForm*"
reveal_type(self.get_form(SubArticleModelForm)) # N: Revealed type is "main.SubArticleModelForm"
reveal_type(self.get_form(AnotherArticleModelForm)) # N: Revealed type is "main.AnotherArticleModelForm" # E: Argument 1 to "get_form" of "FormMixin" has incompatible type "Type[AnotherArticleModelForm]"; expected "Optional[Type[ArticleModelForm]]"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Article(models.Model):
pass