mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-15 00:07:09 +08:00
* Fix stubs related to `(Async)RequestFactory` and `(Async)Client`
* Revert incorrect removal.
* Allow set as `unique_together`, use shared type alias.
* Revert `Q.__init__` to use only `*args, **kwargs` to remove false-positive with `Q(**{...})`
* Add abstract methods to `HttpResponseBase` to create common interface.
* Remove monkey-patched attributes from `HttpResponseBase` subclasses.
* Add QueryDict mutability checks (+ plugin support)
* Fix lint
* Return back GenericForeignKey to `Options.get_fields`
* Minor fixup
* Make plugin code typecheck with `--warn-unreachable`, minor performance increase.
* Better types for `{unique, index}_together` and Options.
* Fix odd type of `URLResolver.urlconf_name` which isn't a str actually.
* Better types for field migration operations.
* Revert form.files to `MultiValueDict[str, UploadedFile]`
* Compatibility fix (#916)
* Do not assume that `Annotated` is always related to django-stubs (fixes #893)
* Restrict `FormView.get_form` return type to `_FormT` (class type argument). Now it is resolved to `form_class` argument if present, but also errors if it is not subclass of _FormT
* Fix CI (make test runnable on 3.8)
* Fix CI (make test runnable on 3.8 _again_)
This commit is contained in:
@@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from django.core.handlers import base as base
|
||||
from django.http import HttpRequest, QueryDict
|
||||
from django.http.request import HttpRequest, _ImmutableQueryDict
|
||||
from django.http.response import HttpResponseBase
|
||||
from django.urls.resolvers import ResolverMatch, URLResolver
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
@@ -34,8 +34,8 @@ class ASGIRequest(HttpRequest):
|
||||
META: Dict[str, Any] = ...
|
||||
def __init__(self, scope: Mapping[str, Any], body_file: IO[bytes]) -> None: ...
|
||||
@property
|
||||
def GET(self) -> QueryDict: ... # type: ignore
|
||||
POST: QueryDict = ...
|
||||
def GET(self) -> _ImmutableQueryDict: ... # type: ignore
|
||||
POST: _ImmutableQueryDict = ...
|
||||
FILES: MultiValueDict = ...
|
||||
@property
|
||||
def COOKIES(self) -> Dict[str, str]: ... # type: ignore
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from django.db.models.fields import Field
|
||||
|
||||
@@ -23,15 +23,15 @@ class AddField(FieldOperation):
|
||||
class RemoveField(FieldOperation): ...
|
||||
|
||||
class AlterField(FieldOperation):
|
||||
field: Any = ...
|
||||
preserve_default: Any = ...
|
||||
field: Field = ...
|
||||
preserve_default: bool = ...
|
||||
def __init__(self, model_name: str, name: str, field: Field, preserve_default: bool = ...) -> None: ...
|
||||
|
||||
class RenameField(FieldOperation):
|
||||
old_name: Any = ...
|
||||
new_name: Any = ...
|
||||
old_name: str = ...
|
||||
new_name: str = ...
|
||||
def __init__(self, model_name: str, old_name: str, new_name: str) -> None: ...
|
||||
@property
|
||||
def old_name_lower(self): ...
|
||||
def old_name_lower(self) -> str: ...
|
||||
@property
|
||||
def new_name_lower(self) -> str: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.migrations.operations.base import Operation
|
||||
@@ -7,9 +7,7 @@ from django.db.models.constraints import BaseConstraint
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.indexes import Index
|
||||
from django.db.models.manager import Manager
|
||||
from django.utils.datastructures import _ListOrTuple
|
||||
|
||||
_T = TypeVar("_T")
|
||||
from django.db.models.options import _OptionTogetherT
|
||||
|
||||
class ModelOperation(Operation):
|
||||
name: str = ...
|
||||
@@ -53,10 +51,10 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
option_value: _ListOrTuple[Tuple[str, str]],
|
||||
option_value: Optional[_OptionTogetherT],
|
||||
) -> None: ...
|
||||
@property
|
||||
def option_value(self) -> Set[Tuple[str, str]]: ...
|
||||
def option_value(self) -> Optional[Set[Tuple[str, ...]]]: ...
|
||||
def deconstruct(self) -> Tuple[str, Sequence[Any], Dict[str, Any]]: ...
|
||||
def state_forwards(self, app_label: str, state: Any) -> None: ...
|
||||
def database_forwards(
|
||||
@@ -72,26 +70,26 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
|
||||
|
||||
class AlterUniqueTogether(AlterTogetherOptionOperation):
|
||||
option_name: str = ...
|
||||
unique_together: _ListOrTuple[Tuple[str, str]] = ...
|
||||
def __init__(self, name: str, unique_together: _ListOrTuple[Tuple[str, str]]) -> None: ...
|
||||
unique_together: Optional[Set[Tuple[str, ...]]] = ...
|
||||
def __init__(self, name: str, unique_together: Optional[_OptionTogetherT]) -> None: ...
|
||||
|
||||
class AlterIndexTogether(AlterTogetherOptionOperation):
|
||||
option_name: str = ...
|
||||
index_together: _ListOrTuple[Tuple[str, str]] = ...
|
||||
def __init__(self, name: str, index_together: _ListOrTuple[Tuple[str, str]]) -> None: ...
|
||||
index_together: Optional[Set[Tuple[str, ...]]] = ...
|
||||
def __init__(self, name: str, index_together: Optional[_OptionTogetherT]) -> None: ...
|
||||
|
||||
class AlterOrderWithRespectTo(ModelOptionOperation):
|
||||
order_with_respect_to: str = ...
|
||||
def __init__(self, name: str, order_with_respect_to: str) -> None: ...
|
||||
|
||||
class AlterModelOptions(ModelOptionOperation):
|
||||
ALTER_OPTION_KEYS: Any = ...
|
||||
options: Dict[str, str] = ...
|
||||
ALTER_OPTION_KEYS: List[str] = ...
|
||||
options: Dict[str, Any] = ...
|
||||
def __init__(self, name: str, options: Dict[str, Any]) -> None: ...
|
||||
|
||||
class AlterModelManagers(ModelOptionOperation):
|
||||
managers: Any = ...
|
||||
def __init__(self, name: Any, managers: Any) -> None: ...
|
||||
managers: Sequence[Manager] = ...
|
||||
def __init__(self, name: str, managers: Sequence[Manager]) -> None: ...
|
||||
|
||||
class IndexOperation(Operation):
|
||||
option_name: str = ...
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import sys
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Mapping, Optional, Sequence, Union
|
||||
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.migrations.state import StateApps
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
from django.utils.datastructures import _ListOrTuple
|
||||
|
||||
from .base import Operation
|
||||
|
||||
@@ -25,14 +21,14 @@ class SeparateDatabaseAndState(Operation):
|
||||
|
||||
class RunSQL(Operation):
|
||||
noop: Literal[""] = ...
|
||||
sql: Union[str, List[str], Tuple[str, ...]] = ...
|
||||
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ...
|
||||
sql: Union[str, _ListOrTuple[str]] = ...
|
||||
reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...
|
||||
state_operations: Sequence[Operation] = ...
|
||||
hints: Mapping[str, Any] = ...
|
||||
def __init__(
|
||||
self,
|
||||
sql: Union[str, List[str], Tuple[str, ...]],
|
||||
reverse_sql: Optional[Union[str, List[str], Tuple[str, ...]]] = ...,
|
||||
sql: Union[str, _ListOrTuple[str]],
|
||||
reverse_sql: Optional[Union[str, _ListOrTuple[str]]] = ...,
|
||||
state_operations: Sequence[Operation] = ...,
|
||||
hints: Optional[Mapping[str, Any]] = ...,
|
||||
elidable: bool = ...,
|
||||
@@ -44,13 +40,13 @@ class _CodeCallable(Protocol):
|
||||
class RunPython(Operation):
|
||||
code: _CodeCallable = ...
|
||||
reverse_code: Optional[_CodeCallable] = ...
|
||||
hints: Optional[Dict[str, Any]] = ...
|
||||
hints: Mapping[str, Any] = ...
|
||||
def __init__(
|
||||
self,
|
||||
code: _CodeCallable,
|
||||
reverse_code: Optional[_CodeCallable] = ...,
|
||||
atomic: Optional[bool] = ...,
|
||||
hints: Optional[Dict[str, Any]] = ...,
|
||||
hints: Optional[Mapping[str, Any]] = ...,
|
||||
elidable: bool = ...,
|
||||
) -> None: ...
|
||||
@staticmethod
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
|
||||
import sys
|
||||
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from django.apps.config import AppConfig
|
||||
from django.apps.registry import Apps
|
||||
@@ -11,16 +12,26 @@ from django.db.models.fields.related import ManyToManyField, OneToOneField
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||
from django.db.models.manager import Manager
|
||||
from django.db.models.query_utils import PathInfo
|
||||
from django.utils.datastructures import ImmutableList
|
||||
from django.utils.datastructures import ImmutableList, _ListOrTuple
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
PROXY_PARENTS: object
|
||||
EMPTY_RELATION_TREE: Any
|
||||
IMMUTABLE_WARNING: str
|
||||
DEFAULT_NAMES: Tuple[str, ...]
|
||||
|
||||
def normalize_together(
|
||||
option_together: Union[List[Tuple[str, str]], Tuple[Tuple[str, str], ...], Tuple[()], Tuple[str, str]]
|
||||
) -> Tuple[Tuple[str, str], ...]: ...
|
||||
_OptionTogetherT = Union[_ListOrTuple[Union[_ListOrTuple[str], str]], Set[Tuple[str, ...]]]
|
||||
|
||||
@overload
|
||||
def normalize_together(option_together: _ListOrTuple[Union[_ListOrTuple[str], str]]) -> Tuple[Tuple[str, ...], ...]: ...
|
||||
|
||||
# Any other value will be returned unchanged, but probably only set is semantically allowed
|
||||
@overload
|
||||
def normalize_together(option_together: Set[Tuple[str, ...]]) -> Set[Tuple[str, ...]]: ...
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@@ -45,18 +56,18 @@ class Options(Generic[_M]):
|
||||
db_table: str = ...
|
||||
ordering: Optional[Sequence[str]] = ...
|
||||
indexes: List[Any] = ...
|
||||
unique_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ...
|
||||
index_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] = ...
|
||||
unique_together: Sequence[Tuple[str]] = ... # Are always normalized
|
||||
index_together: Sequence[Tuple[str]] = ... # Are always normalized
|
||||
select_on_save: bool = ...
|
||||
default_permissions: Sequence[str] = ...
|
||||
permissions: List[Any] = ...
|
||||
object_name: Optional[str] = ...
|
||||
app_label: str = ...
|
||||
get_latest_by: Optional[Sequence[str]] = ...
|
||||
order_with_respect_to: Optional[Any] = ...
|
||||
order_with_respect_to: Optional[str] = ...
|
||||
db_tablespace: str = ...
|
||||
required_db_features: List[Any] = ...
|
||||
required_db_vendor: Any = ...
|
||||
required_db_features: List[str] = ...
|
||||
required_db_vendor: Optional[Literal["sqlite", "postgresql", "mysql", "oracle"]] = ...
|
||||
meta: Optional[type] = ...
|
||||
pk: Optional[Field] = ...
|
||||
auto_field: Optional[AutoField] = ...
|
||||
@@ -105,7 +116,7 @@ class Options(Generic[_M]):
|
||||
def default_manager(self) -> Optional[Manager]: ...
|
||||
@property
|
||||
def fields(self) -> ImmutableList[Field[Any, Any]]: ...
|
||||
def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel]: ...
|
||||
def get_field(self, field_name: str) -> Union[Field, ForeignObjectRel, GenericForeignKey]: ...
|
||||
def get_base_chain(self, model: Type[Model]) -> List[Type[Model]]: ...
|
||||
def get_parent_list(self) -> List[Type[Model]]: ...
|
||||
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
|
||||
@@ -113,7 +124,7 @@ class Options(Generic[_M]):
|
||||
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
|
||||
def get_fields(
|
||||
self, include_parents: bool = ..., include_hidden: bool = ...
|
||||
) -> List[Union[Field[Any, Any], ForeignObjectRel]]: ...
|
||||
) -> List[Union[Field[Any, Any], ForeignObjectRel, GenericForeignKey]]: ...
|
||||
@property
|
||||
def total_unique_constraints(self) -> List[UniqueConstraint]: ...
|
||||
@property
|
||||
|
||||
@@ -45,7 +45,9 @@ class Q(tree.Node):
|
||||
AND: str = ...
|
||||
OR: str = ...
|
||||
conditional: bool = ...
|
||||
def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ...
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
# Fake signature, the real is
|
||||
# def __init__(self, *args: Any, _connector: Optional[Any] = ..., _negated: bool = ..., **kwargs: Any) -> None: ...
|
||||
def __or__(self, other: Q) -> Q: ...
|
||||
def __and__(self, other: Q) -> Q: ...
|
||||
def __invert__(self) -> Q: ...
|
||||
|
||||
@@ -41,6 +41,8 @@ class ConnectionRouter:
|
||||
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
|
||||
@property
|
||||
def routers(self) -> List[Any]: ...
|
||||
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
|
||||
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
|
||||
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
|
||||
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
|
||||
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...
|
||||
|
||||
@@ -3,11 +3,12 @@ from datetime import datetime
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.files import File
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
from django.utils.safestring import SafeString
|
||||
|
||||
_DataT = Mapping[str, Any]
|
||||
_FilesT = Mapping[str, Iterable[File]]
|
||||
_FilesT = MultiValueDict[str, UploadedFile]
|
||||
|
||||
def pretty_name(name: str) -> str: ...
|
||||
def flatatt(attrs: Dict[str, Any]) -> SafeString: ...
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -6,6 +7,7 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
@@ -24,6 +26,11 @@ from django.core.files import uploadedfile, uploadhandler
|
||||
from django.urls import ResolverMatch
|
||||
from django.utils.datastructures import CaseInsensitiveMapping, ImmutableList, MultiValueDict
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Literal
|
||||
else:
|
||||
from typing import Literal
|
||||
|
||||
RAISE_ERROR: object = ...
|
||||
host_validation_re: Pattern[str] = ...
|
||||
|
||||
@@ -40,8 +47,8 @@ class HttpHeaders(CaseInsensitiveMapping[str]):
|
||||
def parse_header_name(cls, header: str) -> Optional[str]: ...
|
||||
|
||||
class HttpRequest(BytesIO):
|
||||
GET: QueryDict = ...
|
||||
POST: QueryDict = ...
|
||||
GET: _ImmutableQueryDict = ...
|
||||
POST: _ImmutableQueryDict = ...
|
||||
COOKIES: Dict[str, str] = ...
|
||||
META: Dict[str, Any] = ...
|
||||
FILES: MultiValueDict[str, uploadedfile.UploadedFile] = ...
|
||||
@@ -55,7 +62,15 @@ class HttpRequest(BytesIO):
|
||||
site: Site
|
||||
session: SessionBase
|
||||
_stream: BinaryIO
|
||||
def __init__(self) -> None: ...
|
||||
# The magic. If we instantiate HttpRequest directly somewhere, it has
|
||||
# mutable GET and POST. However, both ASGIRequest and WSGIRequest have immutable,
|
||||
# so when we use HttpRequest to refer to any of them we want exactly this.
|
||||
# Case when some function creates *exactly* HttpRequest (not subclass)
|
||||
# remain uncovered, however it's probably the best solution we can afford.
|
||||
def __new__(cls) -> _MutableHttpRequest: ...
|
||||
# When both __init__ and __new__ are present, mypy will prefer __init__
|
||||
# (see comments in mypy.checkmember.type_object_type)
|
||||
# def __init__(self) -> None: ...
|
||||
def get_host(self) -> str: ...
|
||||
def get_port(self) -> str: ...
|
||||
def get_full_path(self, force_append_slash: bool = ...) -> str: ...
|
||||
@@ -90,17 +105,43 @@ class HttpRequest(BytesIO):
|
||||
def _load_post_and_files(self) -> None: ...
|
||||
def accepts(self, media_type: str) -> bool: ...
|
||||
|
||||
class _MutableHttpRequest(HttpRequest):
|
||||
GET: QueryDict = ... # type: ignore[assignment]
|
||||
POST: QueryDict = ... # type: ignore[assignment]
|
||||
|
||||
_Q = TypeVar("_Q", bound="QueryDict")
|
||||
_Z = TypeVar("_Z")
|
||||
|
||||
class QueryDict(MultiValueDict[str, str]):
|
||||
_mutable: bool = ...
|
||||
# We can make it mutable only by specifying `mutable=True`.
|
||||
# It can be done a) with kwarg and b) with pos. arg. `overload` has
|
||||
# some problems with args/kwargs + Literal, so two signatures are required.
|
||||
# ('querystring', True, [...])
|
||||
@overload
|
||||
def __init__(
|
||||
self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ...
|
||||
self: QueryDict,
|
||||
query_string: Optional[Union[str, bytes]],
|
||||
mutable: Literal[True],
|
||||
encoding: Optional[str] = ...,
|
||||
) -> None: ...
|
||||
# ([querystring='string',] mutable=True, [...])
|
||||
@overload
|
||||
def __init__(
|
||||
self: QueryDict,
|
||||
*,
|
||||
mutable: Literal[True],
|
||||
query_string: Optional[Union[str, bytes]] = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
) -> None: ...
|
||||
# Otherwise it's immutable
|
||||
@overload
|
||||
def __init__( # type: ignore[misc]
|
||||
self: _ImmutableQueryDict,
|
||||
query_string: Optional[Union[str, bytes]] = ...,
|
||||
mutable: bool = ...,
|
||||
encoding: Optional[str] = ...,
|
||||
) -> None: ...
|
||||
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
|
||||
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
|
||||
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
|
||||
def urlencode(self, safe: Optional[str] = ...) -> str: ...
|
||||
@classmethod
|
||||
def fromkeys( # type: ignore
|
||||
cls: Type[_Q],
|
||||
@@ -114,7 +155,46 @@ class QueryDict(MultiValueDict[str, str]):
|
||||
@encoding.setter
|
||||
def encoding(self, value: str) -> None: ...
|
||||
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
|
||||
def __delitem__(self, key: Union[str, bytes]) -> None: ...
|
||||
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> None: ...
|
||||
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> List[str]: ...
|
||||
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> None: ...
|
||||
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
|
||||
@overload
|
||||
def pop(self, key: Union[str, bytes], /) -> str: ...
|
||||
@overload
|
||||
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> Union[str, _Z]: ...
|
||||
def popitem(self) -> Tuple[str, str]: ...
|
||||
def clear(self) -> None: ...
|
||||
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> str: ...
|
||||
def copy(self) -> QueryDict: ...
|
||||
def urlencode(self, safe: Optional[str] = ...) -> str: ...
|
||||
|
||||
class _ImmutableQueryDict(QueryDict):
|
||||
_mutable: Literal[False]
|
||||
# def __init__(
|
||||
# self, query_string: Optional[Union[str, bytes]] = ..., mutable: bool = ..., encoding: Optional[str] = ...
|
||||
# ) -> None: ...
|
||||
def __setitem__(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
|
||||
def __delitem__(self, key: Union[str, bytes]) -> NoReturn: ...
|
||||
def setlist(self, key: Union[str, bytes], list_: Iterable[Union[str, bytes]]) -> NoReturn: ...
|
||||
def setlistdefault(self, key: Union[str, bytes], default_list: Optional[List[str]] = ...) -> NoReturn: ...
|
||||
def appendlist(self, key: Union[str, bytes], value: Union[str, bytes]) -> NoReturn: ...
|
||||
# Fake signature (because *args is used in source, but it fails with more that 1 argument)
|
||||
@overload
|
||||
def pop(self, key: Union[str, bytes], /) -> NoReturn: ...
|
||||
@overload
|
||||
def pop(self, key: Union[str, bytes], default: Union[str, _Z] = ..., /) -> NoReturn: ...
|
||||
def popitem(self) -> NoReturn: ...
|
||||
def clear(self) -> NoReturn: ...
|
||||
def setdefault(self, key: Union[str, bytes], default: Union[str, bytes, None] = ...) -> NoReturn: ...
|
||||
def copy(self) -> QueryDict: ... # type: ignore[override]
|
||||
def urlencode(self, safe: Optional[str] = ...) -> str: ...
|
||||
# Fakes for convenience (for `request.GET` and `request.POST`). If dict
|
||||
# was created by Django, there is no chance to hit `List[object]` (empty list)
|
||||
# edge case.
|
||||
def __getitem__(self, key: str) -> str: ...
|
||||
def dict(self) -> Dict[str, str]: ... # type: ignore[override]
|
||||
|
||||
class MediaType:
|
||||
def __init__(self, media_type_raw_line: str) -> None: ...
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
from json import JSONEncoder
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, overload, type_check_only
|
||||
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.http.cookie import SimpleCookie
|
||||
from django.template import Context, Template
|
||||
from django.test.client import Client
|
||||
from django.urls import ResolverMatch
|
||||
from django.utils.datastructures import CaseInsensitiveMapping, _PropertyDescriptor
|
||||
|
||||
class BadHeaderError(ValueError): ...
|
||||
@@ -82,6 +78,12 @@ class HttpResponseBase:
|
||||
def writable(self) -> bool: ...
|
||||
def writelines(self, lines: Iterable[object]) -> None: ...
|
||||
|
||||
# Fake methods that are implemented by all subclasses
|
||||
@type_check_only
|
||||
def __iter__(self) -> Iterator[bytes]: ...
|
||||
@type_check_only
|
||||
def getvalue(self) -> bytes: ...
|
||||
|
||||
class HttpResponse(HttpResponseBase, Iterable[bytes]):
|
||||
content = _PropertyDescriptor[object, bytes]()
|
||||
csrf_cookie_set: bool
|
||||
@@ -94,14 +96,6 @@ class HttpResponse(HttpResponseBase, Iterable[bytes]):
|
||||
def serialize(self) -> bytes: ...
|
||||
__bytes__ = serialize
|
||||
def __iter__(self) -> Iterator[bytes]: ...
|
||||
# Attributes assigned by monkey-patching in test client ClientHandler.__call__()
|
||||
wsgi_request: WSGIRequest
|
||||
# Attributes assigned by monkey-patching in test client Client.request()
|
||||
client: Client
|
||||
request: Dict[str, Any]
|
||||
templates: List[Template]
|
||||
context: Context
|
||||
resolver_match: ResolverMatch
|
||||
def getvalue(self) -> bytes: ...
|
||||
|
||||
class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
|
||||
@@ -111,13 +105,7 @@ class StreamingHttpResponse(HttpResponseBase, Iterable[bytes]):
|
||||
def getvalue(self) -> bytes: ...
|
||||
|
||||
class FileResponse(StreamingHttpResponse):
|
||||
client: Client
|
||||
context: None
|
||||
file_to_stream: Optional[BytesIO]
|
||||
request: Dict[str, str]
|
||||
resolver_match: ResolverMatch
|
||||
templates: List[Any]
|
||||
wsgi_request: WSGIRequest
|
||||
block_size: int = ...
|
||||
as_attachment: bool = ...
|
||||
filename: str = ...
|
||||
|
||||
@@ -3,6 +3,7 @@ from json import JSONEncoder
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
@@ -26,6 +27,8 @@ from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.http.cookie import SimpleCookie
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponseBase
|
||||
from django.template.base import Template
|
||||
from django.urls import ResolverMatch
|
||||
|
||||
BOUNDARY: str = ...
|
||||
MULTIPART_CONTENT: str = ...
|
||||
@@ -49,27 +52,25 @@ _T = TypeVar("_T")
|
||||
def closing_iterator_wrapper(iterable: Iterable[_T], close: Callable[[], Any]) -> Iterator[_T]: ...
|
||||
def conditional_content_removal(request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ...
|
||||
|
||||
class _WSGIResponse(HttpResponseBase):
|
||||
wsgi_request: WSGIRequest
|
||||
|
||||
class _ASGIResponse(HttpResponseBase):
|
||||
asgi_request: ASGIRequest
|
||||
|
||||
class ClientHandler(BaseHandler):
|
||||
enforce_csrf_checks: bool = ...
|
||||
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
|
||||
def __call__(self, environ: Dict[str, Any]) -> HttpResponseBase: ...
|
||||
def __call__(self, environ: Dict[str, Any]) -> _WSGIResponse: ...
|
||||
|
||||
class AsyncClientHandler(BaseHandler):
|
||||
enforce_csrf_checks: bool = ...
|
||||
def __init__(self, enforce_csrf_checks: bool = ..., *args: Any, **kwargs: Any) -> None: ...
|
||||
async def __call__(self, scope: Dict[str, Any]) -> HttpResponseBase: ...
|
||||
async def __call__(self, scope: Dict[str, Any]) -> _ASGIResponse: ...
|
||||
|
||||
def encode_multipart(boundary: str, data: Dict[str, Any]) -> bytes: ...
|
||||
def encode_file(boundary: str, key: str, file: Any) -> List[bytes]: ...
|
||||
|
||||
# fake to distinguish WSGIRequest and ASGIRequest
|
||||
|
||||
_R = TypeVar("_R", bound=HttpRequest)
|
||||
|
||||
class _MonkeyPatchedHttpResponseBase(Generic[_R], HttpResponseBase):
|
||||
def json(self) -> Any: ...
|
||||
wsgi_request: _R
|
||||
|
||||
class _RequestFactory(Generic[_T]):
|
||||
json_encoder: Type[JSONEncoder]
|
||||
defaults: Dict[str, str]
|
||||
@@ -102,9 +103,9 @@ class _RequestFactory(Generic[_T]):
|
||||
**extra: Any
|
||||
) -> _T: ...
|
||||
|
||||
class RequestFactory(_RequestFactory[_T]): ...
|
||||
class RequestFactory(_RequestFactory[WSGIRequest]): ...
|
||||
|
||||
class AsyncRequestFactory(_RequestFactory[_T]):
|
||||
class _AsyncRequestFactory(_RequestFactory[_T]):
|
||||
def request(self, **request: Any) -> _T: ...
|
||||
def generic(
|
||||
self,
|
||||
@@ -116,6 +117,25 @@ class AsyncRequestFactory(_RequestFactory[_T]):
|
||||
**extra: Any
|
||||
) -> _T: ...
|
||||
|
||||
class AsyncRequestFactory(_AsyncRequestFactory[ASGIRequest]): ...
|
||||
|
||||
# fakes to distinguish WSGIRequest and ASGIRequest
|
||||
class _MonkeyPatchedWSGIResponse(_WSGIResponse):
|
||||
def json(self) -> Any: ...
|
||||
request: Dict[str, Any]
|
||||
client: Client
|
||||
templates: List[Template]
|
||||
context: List[Dict[str, Any]]
|
||||
resolver_match: ResolverMatch
|
||||
|
||||
class _MonkeyPatchedASGIResponse(_ASGIResponse):
|
||||
def json(self) -> Any: ...
|
||||
request: Dict[str, Any]
|
||||
client: AsyncClient
|
||||
templates: List[Template]
|
||||
context: List[Dict[str, Any]]
|
||||
resolver_match: ResolverMatch
|
||||
|
||||
class ClientMixin:
|
||||
def store_exc_info(self, **kwargs: Any) -> None: ...
|
||||
def check_exception(self, response: HttpResponseBase) -> NoReturn: ...
|
||||
@@ -125,7 +145,7 @@ class ClientMixin:
|
||||
def force_login(self, user: AbstractBaseUser, backend: Optional[str] = ...) -> None: ...
|
||||
def logout(self) -> None: ...
|
||||
|
||||
class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequest]]):
|
||||
class Client(ClientMixin, _RequestFactory[_MonkeyPatchedWSGIResponse]):
|
||||
handler: ClientHandler
|
||||
raise_request_exception: bool
|
||||
exc_info: Optional[Tuple[Type[BaseException], BaseException, TracebackType]]
|
||||
@@ -133,19 +153,19 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
|
||||
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
|
||||
) -> None: ...
|
||||
# Silence type warnings, since this class overrides arguments and return types in an unsafe manner.
|
||||
def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
def request(self, **request: Any) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def get( # type: ignore
|
||||
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def post( # type: ignore
|
||||
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def head( # type: ignore
|
||||
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def trace( # type: ignore
|
||||
self, path: str, data: Any = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def options( # type: ignore
|
||||
self,
|
||||
path: str,
|
||||
@@ -154,18 +174,18 @@ class Client(ClientMixin, RequestFactory[_MonkeyPatchedHttpResponseBase[WSGIRequ
|
||||
follow: bool = ...,
|
||||
secure: bool = ...,
|
||||
**extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def put( # type: ignore
|
||||
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def patch( # type: ignore
|
||||
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
def delete( # type: ignore
|
||||
self, path: str, data: Any = ..., content_type: str = ..., follow: bool = ..., secure: bool = ..., **extra: Any
|
||||
) -> _MonkeyPatchedHttpResponseBase[WSGIRequest]: ...
|
||||
) -> _MonkeyPatchedWSGIResponse: ...
|
||||
|
||||
class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBase[ASGIRequest]]):
|
||||
class AsyncClient(ClientMixin, _AsyncRequestFactory[Awaitable[_MonkeyPatchedASGIResponse]]):
|
||||
handler: AsyncClientHandler
|
||||
raise_request_exception: bool
|
||||
exc_info: Any
|
||||
@@ -173,4 +193,4 @@ class AsyncClient(ClientMixin, AsyncRequestFactory[_MonkeyPatchedHttpResponseBas
|
||||
def __init__(
|
||||
self, enforce_csrf_checks: bool = ..., raise_request_exception: bool = ..., **defaults: Any
|
||||
) -> None: ...
|
||||
async def request(self, **request: Any) -> _MonkeyPatchedHttpResponseBase[ASGIRequest]: ... # type: ignore
|
||||
async def request(self, **request: Any) -> _MonkeyPatchedASGIResponse: ...
|
||||
|
||||
@@ -28,7 +28,7 @@ from django.db.models.query import QuerySet, RawQuerySet
|
||||
from django.forms.fields import EmailField
|
||||
from django.http.response import HttpResponse, HttpResponseBase
|
||||
from django.template.base import Template
|
||||
from django.test.client import Client
|
||||
from django.test.client import AsyncClient, Client
|
||||
from django.test.html import Element
|
||||
from django.test.utils import CaptureQueriesContext, ContextList
|
||||
from django.utils.functional import classproperty
|
||||
@@ -64,8 +64,10 @@ class _DatabaseFailure:
|
||||
def __call__(self) -> None: ...
|
||||
|
||||
class SimpleTestCase(unittest.TestCase):
|
||||
client_class: Any = ...
|
||||
client_class: Type[Client] = ...
|
||||
client: Client
|
||||
async_client_class: Type[AsyncClient] = ...
|
||||
async_client: AsyncClient
|
||||
allow_database_queries: bool = ...
|
||||
# TODO: str -> Literal['__all__']
|
||||
databases: Union[Set[str], str] = ...
|
||||
@@ -142,9 +144,7 @@ class SimpleTestCase(unittest.TestCase):
|
||||
) -> Any: ...
|
||||
def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
|
||||
def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
|
||||
def assertInHTML(
|
||||
self, needle: str, haystack: SafeString, count: Optional[int] = ..., msg_prefix: str = ...
|
||||
) -> None: ...
|
||||
def assertInHTML(self, needle: str, haystack: str, count: Optional[int] = ..., msg_prefix: str = ...) -> None: ...
|
||||
def assertJSONEqual(
|
||||
self,
|
||||
raw: str,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Pattern, Tuple, Type, Union, overload
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Pattern, Sequence, Tuple, Type, Union, overload
|
||||
|
||||
from django.core.checks.messages import CheckMessage
|
||||
from django.urls.converters import UUIDConverter
|
||||
@@ -93,7 +94,7 @@ class URLPattern:
|
||||
|
||||
class URLResolver:
|
||||
pattern: LocalePrefixPattern = ...
|
||||
urlconf_name: Optional[str] = ...
|
||||
urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]] = ...
|
||||
callback: None = ...
|
||||
default_kwargs: Dict[str, Any] = ...
|
||||
namespace: Optional[str] = ...
|
||||
@@ -103,7 +104,7 @@ class URLResolver:
|
||||
def __init__(
|
||||
self,
|
||||
pattern: LocalePrefixPattern,
|
||||
urlconf_name: Optional[str],
|
||||
urlconf_name: Union[str, None, Sequence[Union[URLPattern, URLResolver]]],
|
||||
default_kwargs: Optional[Dict[str, Any]] = ...,
|
||||
app_name: Optional[str] = ...,
|
||||
namespace: Optional[str] = ...,
|
||||
@@ -115,7 +116,7 @@ class URLResolver:
|
||||
@property
|
||||
def app_dict(self) -> Dict[str, List[str]]: ...
|
||||
@property
|
||||
def urlconf_module(self) -> Optional[List[Tuple[str, Callable]]]: ...
|
||||
def urlconf_module(self) -> Union[ModuleType, None, Sequence[Union[URLPattern, URLResolver]]]: ...
|
||||
@property
|
||||
def url_patterns(self) -> List[Union[URLPattern, URLResolver]]: ...
|
||||
def resolve(self, path: str) -> ResolverMatch: ...
|
||||
|
||||
@@ -72,7 +72,7 @@ class classproperty(Generic[_Get]):
|
||||
fget: Optional[Callable[[_Self], _Get]] = ...
|
||||
def __init__(self, method: Optional[Callable[[_Self], _Get]] = ...) -> None: ...
|
||||
def __get__(self, instance: Optional[_Self], cls: Type[_Self] = ...) -> _Get: ...
|
||||
def getter(self, method: Callable[[Type[_Self]], _Get]) -> classproperty[_Get]: ...
|
||||
def getter(self, method: Callable[[_Self], _Get]) -> classproperty[_Get]: ...
|
||||
|
||||
class _Getter(Protocol[_Get]):
|
||||
"""Type fake to declare some read-only properties (until `property` builtin is generic)
|
||||
|
||||
@@ -26,7 +26,7 @@ class FormMixin(Generic[_FormT], ContextMixin):
|
||||
def get_initial(self) -> Dict[str, Any]: ...
|
||||
def get_prefix(self) -> Optional[str]: ...
|
||||
def get_form_class(self) -> Type[_FormT]: ...
|
||||
def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> BaseForm: ...
|
||||
def get_form(self, form_class: Optional[Type[_FormT]] = ...) -> _FormT: ...
|
||||
def get_form_kwargs(self) -> Dict[str, Any]: ...
|
||||
def get_success_url(self) -> str: ...
|
||||
def form_valid(self, form: _FormT) -> HttpResponse: ...
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from typing import Any, Protocol
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing_extensions import Protocol
|
||||
else:
|
||||
from typing import Protocol
|
||||
|
||||
# Used internally by mypy_django_plugin.
|
||||
class AnyAttrAllowed(Protocol):
|
||||
|
||||
1
mypy.ini
1
mypy.ini
@@ -9,6 +9,7 @@ warn_no_return = False
|
||||
warn_unused_ignores = True
|
||||
warn_redundant_casts = True
|
||||
warn_unused_configs = True
|
||||
warn_unreachable = True
|
||||
|
||||
plugins =
|
||||
mypy_django_plugin.main
|
||||
|
||||
@@ -35,6 +35,7 @@ except ImportError:
|
||||
if TYPE_CHECKING:
|
||||
from django.apps.registry import Apps # noqa: F401
|
||||
from django.conf import LazySettings # noqa: F401
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -100,9 +101,6 @@ class DjangoContext:
|
||||
@cached_property
|
||||
def model_modules(self) -> Dict[str, Set[Type[Model]]]:
|
||||
"""All modules that contain Django models."""
|
||||
if self.apps_registry is None:
|
||||
return {}
|
||||
|
||||
modules: Dict[str, Set[Type[Model]]] = defaultdict(set)
|
||||
for concrete_model_cls in self.apps_registry.get_models():
|
||||
modules[concrete_model_cls.__module__].add(concrete_model_cls)
|
||||
@@ -327,7 +325,7 @@ class DjangoContext:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == "self":
|
||||
if related_model_cls == "self": # type: ignore[unreachable]
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif "." not in related_model_cls:
|
||||
@@ -343,7 +341,7 @@ class DjangoContext:
|
||||
self, field_parts: Iterable[str], model_cls: Type[Model]
|
||||
) -> Union[Field, ForeignObjectRel]:
|
||||
currently_observed_model = model_cls
|
||||
field: Union[Field, ForeignObjectRel, None] = None
|
||||
field: Union[Field, ForeignObjectRel, GenericForeignKey, None] = None
|
||||
for field_part in field_parts:
|
||||
if field_part == "pk":
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
@@ -359,15 +357,16 @@ class DjangoContext:
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
# Guaranteed by `query.solve_lookup_type` before.
|
||||
assert isinstance(field, (Field, ForeignObjectRel))
|
||||
return field
|
||||
|
||||
def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Union[Field, ForeignObjectRel]:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
|
||||
if lookup_parts:
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
|
||||
@@ -35,6 +35,7 @@ RELATED_FIELDS_CLASSES = {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME, MANYTOM
|
||||
MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration"
|
||||
OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options"
|
||||
HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest"
|
||||
QUERYDICT_CLASS_FULLNAME = "django.http.request.QueryDict"
|
||||
|
||||
COMBINABLE_EXPRESSION_FULLNAME = "django.db.models.expressions.Combinable"
|
||||
F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"
|
||||
|
||||
@@ -32,6 +32,7 @@ from mypy_django_plugin.transformers.models import (
|
||||
process_model_class,
|
||||
set_auth_user_model_boolean_fields,
|
||||
)
|
||||
from mypy_django_plugin.transformers.request import check_querydict_is_mutable
|
||||
|
||||
|
||||
def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None:
|
||||
@@ -187,12 +188,21 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
|
||||
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
|
||||
class_fullname, _, method_name = fullname.rpartition(".")
|
||||
if method_name == "get_form_class":
|
||||
# It is looked up very often, specialcase this method for minor speed up
|
||||
if method_name == "__init_subclass__":
|
||||
return None
|
||||
|
||||
if class_fullname.endswith("QueryDict"):
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYDICT_CLASS_FULLNAME):
|
||||
return partial(check_querydict_is_mutable, django_context=self.django_context)
|
||||
|
||||
elif method_name == "get_form_class":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
|
||||
return forms.extract_proper_type_for_get_form_class
|
||||
|
||||
if method_name == "get_form":
|
||||
elif method_name == "get_form":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
|
||||
return forms.extract_proper_type_for_get_form
|
||||
@@ -204,30 +214,30 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
|
||||
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
|
||||
|
||||
if method_name == "values_list":
|
||||
elif method_name == "values_list":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
|
||||
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
|
||||
|
||||
if method_name == "annotate":
|
||||
elif method_name == "annotate":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
|
||||
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)
|
||||
|
||||
if method_name == "get_field":
|
||||
elif method_name == "get_field":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
|
||||
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
|
||||
|
||||
if class_fullname in manager_classes and method_name == "create":
|
||||
elif class_fullname in manager_classes and method_name == "create":
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
|
||||
elif class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
|
||||
return partial(
|
||||
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
|
||||
django_context=self.django_context,
|
||||
)
|
||||
|
||||
if method_name == "from_queryset":
|
||||
elif method_name == "from_queryset":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
|
||||
return fail_if_manager_type_created_in_model_body
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
|
||||
|
||||
from django.db.models.fields import AutoField, Field
|
||||
from django.db.models.fields.related import RelatedField
|
||||
@@ -12,10 +12,13 @@ from mypy.types import TypeOfAny, UnionType
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
|
||||
|
||||
def _get_current_field_from_assignment(
|
||||
ctx: FunctionContext, django_context: DjangoContext
|
||||
) -> Optional[Union[Field, ForeignObjectRel]]:
|
||||
) -> Optional[Union[Field, ForeignObjectRel, "GenericForeignKey"]]:
|
||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
||||
if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context):
|
||||
return None
|
||||
|
||||
@@ -263,8 +263,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
semanal_api.defer()
|
||||
return None
|
||||
original_return_type = method_type.ret_type
|
||||
if original_return_type is None:
|
||||
continue
|
||||
|
||||
# Skip any method that doesn't return _QS
|
||||
original_return_type = get_proper_type(original_return_type)
|
||||
|
||||
@@ -15,7 +15,7 @@ from mypy.types import TypedDictType, TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME
|
||||
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
|
||||
from mypy_django_plugin.lib.helpers import add_new_class_for_module
|
||||
from mypy_django_plugin.transformers import fields
|
||||
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
|
||||
@@ -475,8 +475,8 @@ def handle_annotated_type(ctx: AnalyzeTypeContext, django_context: DjangoContext
|
||||
type_arg = ctx.api.analyze_type(args[0])
|
||||
api = cast(SemanticAnalyzer, ctx.api.api) # type: ignore
|
||||
|
||||
if not isinstance(type_arg, Instance):
|
||||
return ctx.api.analyze_type(ctx.type)
|
||||
if not isinstance(type_arg, Instance) or not type_arg.type.has_base(MODEL_CLASS_FULLNAME):
|
||||
return type_arg
|
||||
|
||||
fields_dict = None
|
||||
if len(args) > 1:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mypy.plugin import AttributeContext
|
||||
from mypy.plugin import AttributeContext, MethodContext
|
||||
from mypy.types import Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import UnionType
|
||||
from mypy.types import UninhabitedType, UnionType
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import helpers
|
||||
@@ -35,3 +35,10 @@ def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_c
|
||||
return ctx.default_attr_type
|
||||
|
||||
return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])])
|
||||
|
||||
|
||||
def check_querydict_is_mutable(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
ret_type = ctx.default_return_type
|
||||
if isinstance(ret_type, UninhabitedType):
|
||||
ctx.api.fail("This QueryDict is immutable.", ctx.context)
|
||||
return ret_type
|
||||
|
||||
33
tests/typecheck/test/test_client.yml
Normal file
33
tests/typecheck/test/test_client.yml
Normal file
@@ -0,0 +1,33 @@
|
||||
- case: client_methods
|
||||
main: |
|
||||
from django.test.client import Client
|
||||
client = Client()
|
||||
response = client.get('foo')
|
||||
reveal_type(response.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"
|
||||
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
|
||||
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
|
||||
reveal_type(response.client) # N: Revealed type is "django.test.client.Client"
|
||||
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
|
||||
response.json()
|
||||
- case: async_client_methods
|
||||
main: |
|
||||
from django.test.client import AsyncClient
|
||||
async def main():
|
||||
client = AsyncClient()
|
||||
response = await client.get('foo')
|
||||
reveal_type(response.asgi_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest"
|
||||
reveal_type(response.request) # N: Revealed type is "builtins.dict[builtins.str, Any]"
|
||||
reveal_type(response.templates) # N: Revealed type is "builtins.list[django.template.base.Template]"
|
||||
reveal_type(response.client) # N: Revealed type is "django.test.client.AsyncClient"
|
||||
reveal_type(response.context) # N: Revealed type is "builtins.list[builtins.dict[builtins.str, Any]]"
|
||||
response.json()
|
||||
- case: request_factories
|
||||
main: |
|
||||
from django.test.client import RequestFactory, AsyncRequestFactory
|
||||
factory = RequestFactory()
|
||||
request = factory.get('foo')
|
||||
reveal_type(request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*"
|
||||
|
||||
async_factory = AsyncRequestFactory()
|
||||
async_request = async_factory.get('foo')
|
||||
reveal_type(async_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest*"
|
||||
@@ -9,4 +9,4 @@
|
||||
reveal_type(resp.status_code) # N: Revealed type is "builtins.int"
|
||||
# Attributes monkey-patched by test Client class:
|
||||
resp.json()
|
||||
reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest*"
|
||||
reveal_type(resp.wsgi_request) # N: Revealed type is "django.core.handlers.wsgi.WSGIRequest"
|
||||
|
||||
17
tests/typecheck/test_annotated.yml
Normal file
17
tests/typecheck/test_annotated.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
# Regression test for #893
|
||||
- case: annotated_should_not_iterfere
|
||||
main: |
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
if sys.version_info < (3, 9):
|
||||
from typing_extensions import Annotated
|
||||
else:
|
||||
from typing import Annotated
|
||||
|
||||
class IntegerType:
|
||||
def __init__(self, min_value: int, max_value: int) -> None:
|
||||
pass
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class RatingComposite:
|
||||
max_value: Annotated[int, IntegerType(min_value=1, max_value=10)] = 5
|
||||
@@ -69,3 +69,41 @@
|
||||
reveal_type(request.user) # N: Revealed type is "django.contrib.auth.models.User"
|
||||
custom_settings: |
|
||||
INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth')
|
||||
|
||||
- case: request_get_post
|
||||
main: |
|
||||
from django.http.request import HttpRequest
|
||||
|
||||
request = HttpRequest()
|
||||
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
|
||||
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
|
||||
request.GET['foo'] = 'bar'
|
||||
|
||||
def mk_request() -> HttpRequest:
|
||||
return HttpRequest()
|
||||
|
||||
req = mk_request()
|
||||
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
|
||||
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
|
||||
req.GET.setdefault('foo', 'bar') # E: This QueryDict is immutable.
|
||||
x = 1 # E: Statement is unreachable
|
||||
|
||||
# Will work after merging https://github.com/python/mypy/pull/12572
|
||||
- case: request_get_post_unreachable
|
||||
main: |
|
||||
from django.http.request import HttpRequest
|
||||
|
||||
request = HttpRequest()
|
||||
reveal_type(request) # N: Revealed type is "django.http.request._MutableHttpRequest"
|
||||
reveal_type(request.GET) # N: Revealed type is "django.http.request.QueryDict"
|
||||
request.GET['foo'] = 'bar'
|
||||
|
||||
def mk_request() -> HttpRequest:
|
||||
return HttpRequest()
|
||||
|
||||
req = mk_request()
|
||||
reveal_type(req) # N: Revealed type is "django.http.request.HttpRequest"
|
||||
reveal_type(req.GET) # N: Revealed type is "django.http.request._ImmutableQueryDict"
|
||||
req.GET['foo'] = 'bar' # E: This QueryDict is immutable.
|
||||
x = 1 # E: Statement is unreachable
|
||||
expect_fail: true
|
||||
|
||||
@@ -61,3 +61,34 @@
|
||||
|
||||
class Article(models.Model):
|
||||
pass
|
||||
|
||||
- case: generic_form_views_different_form_classes
|
||||
main: |
|
||||
from django.views.generic.edit import CreateView
|
||||
from django import forms
|
||||
from myapp.models import Article
|
||||
|
||||
class ArticleModelForm(forms.ModelForm[Article]):
|
||||
class Meta:
|
||||
model = Article
|
||||
class SubArticleModelForm(ArticleModelForm):
|
||||
pass
|
||||
class AnotherArticleModelForm(forms.ModelForm[Article]):
|
||||
class Meta:
|
||||
model = Article
|
||||
|
||||
class MyCreateView(CreateView[Article, ArticleModelForm]):
|
||||
def some(self) -> None:
|
||||
reveal_type(self.get_form()) # N: Revealed type is "main.ArticleModelForm*"
|
||||
reveal_type(self.get_form(SubArticleModelForm)) # N: Revealed type is "main.SubArticleModelForm"
|
||||
reveal_type(self.get_form(AnotherArticleModelForm)) # N: Revealed type is "main.AnotherArticleModelForm" # E: Argument 1 to "get_form" of "FormMixin" has incompatible type "Type[AnotherArticleModelForm]"; expected "Optional[Type[ArticleModelForm]]"
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
|
||||
class Article(models.Model):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user