mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 12:44:29 +08:00
add support for _meta.get_field() typechecking
This commit is contained in:
@@ -1,13 +1,13 @@
|
|||||||
from collections import OrderedDict
|
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from io import BufferedReader, StringIO, TextIOWrapper
|
from io import BufferedReader, StringIO, TextIOWrapper
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.core.management.base import OutputWrapper
|
from django.core.management.base import OutputWrapper
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
from django.db.models.fields.related import ForeignKey, ManyToManyField
|
from django.db.models.fields.related import ForeignKey, ManyToManyField
|
||||||
from django.db.models.query import QuerySet
|
|
||||||
|
from django.db.models.fields import Field
|
||||||
|
|
||||||
class SerializerDoesNotExist(KeyError): ...
|
class SerializerDoesNotExist(KeyError): ...
|
||||||
class SerializationError(Exception): ...
|
class SerializationError(Exception): ...
|
||||||
@@ -43,7 +43,7 @@ class Serializer:
|
|||||||
first: bool = ...
|
first: bool = ...
|
||||||
def serialize(
|
def serialize(
|
||||||
self,
|
self,
|
||||||
queryset: Union[Iterator[Any], List[Model], QuerySet],
|
queryset: Iterable[Model],
|
||||||
*,
|
*,
|
||||||
stream: Optional[Any] = ...,
|
stream: Optional[Any] = ...,
|
||||||
fields: Optional[Any] = ...,
|
fields: Optional[Any] = ...,
|
||||||
@@ -52,7 +52,7 @@ class Serializer:
|
|||||||
progress_output: Optional[Any] = ...,
|
progress_output: Optional[Any] = ...,
|
||||||
object_count: int = ...,
|
object_count: int = ...,
|
||||||
**options: Any
|
**options: Any
|
||||||
) -> Optional[Union[List[OrderedDict], bytes, str]]: ...
|
) -> Any: ...
|
||||||
def start_serialization(self) -> None: ...
|
def start_serialization(self) -> None: ...
|
||||||
def end_serialization(self) -> None: ...
|
def end_serialization(self) -> None: ...
|
||||||
def start_object(self, obj: Any) -> None: ...
|
def start_object(self, obj: Any) -> None: ...
|
||||||
@@ -72,13 +72,16 @@ class Deserializer:
|
|||||||
class DeserializedObject:
|
class DeserializedObject:
|
||||||
object: Any = ...
|
object: Any = ...
|
||||||
m2m_data: Dict[str, List[int]] = ...
|
m2m_data: Dict[str, List[int]] = ...
|
||||||
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
|
deferred_fields: Mapping[Field, Any]
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obj: Model,
|
||||||
|
m2m_data: Optional[Dict[str, List[int]]] = ...,
|
||||||
|
deferred_fields: Optional[Mapping[Field, Any]] = ...,
|
||||||
|
) -> None: ...
|
||||||
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
|
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
|
||||||
|
def save_deferred_fields(self, using: Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ...
|
def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ...
|
||||||
def deserialize_m2m_values(
|
def deserialize_m2m_values(field: ManyToManyField, field_value: Any, using: str) -> List[Any]: ...
|
||||||
field: ManyToManyField, field_value: Union[List[List[str]], List[int]], using: str
|
def deserialize_fk_value(field: ForeignKey, field_value: Any, using: str) -> Any: ...
|
||||||
) -> List[int]: ...
|
|
||||||
def deserialize_fk_value(
|
|
||||||
field: ForeignKey, field_value: Optional[Union[List[str], Tuple[str], int, str]], using: str
|
|
||||||
) -> Optional[Union[int, str, UUID]]: ...
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar
|
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar, Type
|
||||||
|
|
||||||
from django.db.models.manager import Manager
|
from django.db.models.manager import Manager
|
||||||
|
|
||||||
@@ -6,16 +6,16 @@ from django.core.checks.messages import CheckMessage
|
|||||||
|
|
||||||
from django.db.models.options import Options
|
from django.db.models.options import Options
|
||||||
|
|
||||||
class ModelBase(type): ...
|
|
||||||
|
|
||||||
_Self = TypeVar("_Self", bound="Model")
|
_Self = TypeVar("_Self", bound="Model")
|
||||||
|
|
||||||
|
class ModelBase(type): ...
|
||||||
|
|
||||||
class Model(metaclass=ModelBase):
|
class Model(metaclass=ModelBase):
|
||||||
class DoesNotExist(Exception): ...
|
class DoesNotExist(Exception): ...
|
||||||
class MultipleObjectsReturned(Exception): ...
|
class MultipleObjectsReturned(Exception): ...
|
||||||
class Meta: ...
|
class Meta: ...
|
||||||
_meta: Options
|
|
||||||
_default_manager: Manager[Model]
|
_default_manager: Manager[Model]
|
||||||
|
_meta: Options[Any]
|
||||||
pk: Any = ...
|
pk: Any = ...
|
||||||
def __init__(self: _Self, *args, **kwargs) -> None: ...
|
def __init__(self: _Self, *args, **kwargs) -> None: ...
|
||||||
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
|
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
|
||||||
|
|||||||
@@ -39,12 +39,14 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
|||||||
max_length: Optional[int]
|
max_length: Optional[int]
|
||||||
model: Type[Model]
|
model: Type[Model]
|
||||||
name: str
|
name: str
|
||||||
|
verbose_name: str
|
||||||
blank: bool = ...
|
blank: bool = ...
|
||||||
null: bool = ...
|
null: bool = ...
|
||||||
editable: bool = ...
|
editable: bool = ...
|
||||||
choices: Optional[_FieldChoices] = ...
|
choices: Optional[_FieldChoices] = ...
|
||||||
db_column: Optional[str]
|
db_column: Optional[str]
|
||||||
column: str
|
column: str
|
||||||
|
error_messages: _ErrorMessagesToOverride
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
verbose_name: Optional[Union[str, bytes]] = ...,
|
verbose_name: Optional[Union[str, bytes]] = ...,
|
||||||
@@ -78,6 +80,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
|||||||
def get_prep_value(self, value: Any) -> Any: ...
|
def get_prep_value(self, value: Any) -> Any: ...
|
||||||
def get_internal_type(self) -> str: ...
|
def get_internal_type(self) -> str: ...
|
||||||
def formfield(self, **kwargs) -> FormField: ...
|
def formfield(self, **kwargs) -> FormField: ...
|
||||||
|
def save_form_data(self, instance: Model, data: Any) -> None: ...
|
||||||
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
|
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
|
||||||
def to_python(self, value: Any) -> Any: ...
|
def to_python(self, value: Any) -> Any: ...
|
||||||
def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ...
|
def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ...
|
||||||
@@ -91,6 +94,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
|||||||
def has_default(self) -> bool: ...
|
def has_default(self) -> bool: ...
|
||||||
def get_default(self) -> Any: ...
|
def get_default(self) -> Any: ...
|
||||||
def check(self, **kwargs: Any) -> List[checks.Error]: ...
|
def check(self, **kwargs: Any) -> List[checks.Error]: ...
|
||||||
|
@property
|
||||||
|
def validators(self) -> List[_ValidatorCallable]: ...
|
||||||
|
|
||||||
class IntegerField(Field[_ST, _GT]):
|
class IntegerField(Field[_ST, _GT]):
|
||||||
_pyi_private_set_type: Union[float, int, str, Combinable]
|
_pyi_private_set_type: Union[float, int, str, Combinable]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import collections
|
import collections
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, TypeVar, Generic
|
||||||
|
|
||||||
from django.apps.config import AppConfig
|
from django.apps.config import AppConfig
|
||||||
from django.apps.registry import Apps
|
from django.apps.registry import Apps
|
||||||
@@ -29,7 +29,9 @@ def make_immutable_fields_list(
|
|||||||
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
|
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
|
||||||
) -> ImmutableList: ...
|
) -> ImmutableList: ...
|
||||||
|
|
||||||
class Options:
|
_M = TypeVar('_M', bound=Model)
|
||||||
|
|
||||||
|
class Options(Generic[_M]):
|
||||||
base_manager: Manager
|
base_manager: Manager
|
||||||
concrete_fields: ImmutableList
|
concrete_fields: ImmutableList
|
||||||
default_manager: Manager
|
default_manager: Manager
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
|
|||||||
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
|
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
|
||||||
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
|
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
|
||||||
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
|
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
|
||||||
def create(self, **kwargs: Any) -> _T: ...
|
def create(self, *args: Any, **kwargs: Any) -> _T: ...
|
||||||
def bulk_create(
|
def bulk_create(
|
||||||
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
|
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
|
||||||
) -> List[_T]: ...
|
) -> List[_T]: ...
|
||||||
|
|||||||
@@ -35,3 +35,4 @@ RELATED_FIELDS_CLASSES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
||||||
|
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from mypy.types import Type as MypyType
|
|||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers
|
||||||
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings
|
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings, meta
|
||||||
from mypy_django_plugin.transformers.models import process_model_class
|
from mypy_django_plugin.transformers.models import process_model_class
|
||||||
|
|
||||||
|
|
||||||
@@ -168,15 +168,20 @@ class NewSemanalDjangoPlugin(Plugin):
|
|||||||
return forms.extract_proper_type_for_get_form
|
return forms.extract_proper_type_for_get_form
|
||||||
|
|
||||||
if method_name == 'values':
|
if method_name == 'values':
|
||||||
model_info = self._get_typeinfo_or_none(class_fullname)
|
info = self._get_typeinfo_or_none(class_fullname)
|
||||||
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||||
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
|
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
|
||||||
|
|
||||||
if method_name == 'values_list':
|
if method_name == 'values_list':
|
||||||
model_info = self._get_typeinfo_or_none(class_fullname)
|
info = self._get_typeinfo_or_none(class_fullname)
|
||||||
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||||
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
|
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
|
||||||
|
|
||||||
|
if method_name == 'get_field':
|
||||||
|
info = self._get_typeinfo_or_none(class_fullname)
|
||||||
|
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
|
||||||
|
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
|
||||||
|
|
||||||
manager_classes = self._get_current_manager_bases()
|
manager_classes = self._get_current_manager_bases()
|
||||||
if class_fullname in manager_classes and method_name == 'create':
|
if class_fullname in manager_classes and method_name == 'create':
|
||||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||||
|
|||||||
38
mypy_django_plugin/transformers/meta.py
Normal file
38
mypy_django_plugin/transformers/meta.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from django.core.exceptions import FieldDoesNotExist
|
||||||
|
from mypy.plugin import MethodContext
|
||||||
|
from mypy.types import AnyType, Type as MypyType, TypeOfAny, Instance
|
||||||
|
|
||||||
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
|
from mypy_django_plugin.lib import fullnames, helpers
|
||||||
|
|
||||||
|
|
||||||
|
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||||
|
field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname)
|
||||||
|
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
|
||||||
|
|
||||||
|
|
||||||
|
def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||||
|
model_type = ctx.type.args[0]
|
||||||
|
if not isinstance(model_type, Instance):
|
||||||
|
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||||
|
|
||||||
|
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||||
|
if model_cls is None:
|
||||||
|
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||||
|
|
||||||
|
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||||
|
if field_name_expr is None:
|
||||||
|
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||||
|
|
||||||
|
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
|
||||||
|
if field_name is None:
|
||||||
|
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||||
|
|
||||||
|
try:
|
||||||
|
field = model_cls._meta.get_field(field_name)
|
||||||
|
except FieldDoesNotExist as exc:
|
||||||
|
ctx.api.fail(exc.args[0], ctx.context)
|
||||||
|
return AnyType(TypeOfAny.from_error)
|
||||||
|
|
||||||
|
field_fullname = helpers.get_class_fullname(field.__class__)
|
||||||
|
return _get_field_instance(ctx, field_fullname)
|
||||||
@@ -2,15 +2,15 @@ from abc import ABCMeta, abstractmethod
|
|||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from django.db.models.fields import DateTimeField, DateField
|
|
||||||
from django.db.models.fields.related import ForeignKey
|
from django.db.models.fields.related import ForeignKey
|
||||||
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
|
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
|
||||||
from mypy.newsemanal.semanal import NewSemanticAnalyzer
|
from mypy.newsemanal.semanal import NewSemanticAnalyzer
|
||||||
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2
|
from mypy.nodes import ARG_STAR2, Argument, MDEF, SymbolTableNode, TypeInfo, Var
|
||||||
from mypy.plugin import ClassDefContext
|
from mypy.plugin import ClassDefContext
|
||||||
from mypy.plugins import common
|
from mypy.plugins import common
|
||||||
from mypy.types import Instance, TypeOfAny, AnyType
|
from mypy.types import AnyType, Instance, TypeOfAny
|
||||||
|
|
||||||
|
from django.db.models.fields import DateField, DateTimeField
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers
|
||||||
from mypy_django_plugin.transformers import fields
|
from mypy_django_plugin.transformers import fields
|
||||||
@@ -177,6 +177,15 @@ class AddExtraFieldMethods(ModelClassInitializer):
|
|||||||
return_type=return_type)
|
return_type=return_type)
|
||||||
|
|
||||||
|
|
||||||
|
class AddMetaOptionsAttribute(ModelClassInitializer):
|
||||||
|
def run(self):
|
||||||
|
if '_meta' not in self.model_classdef.info.names:
|
||||||
|
options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME)
|
||||||
|
self.add_new_node_to_model_class('_meta',
|
||||||
|
Instance(options_info, [
|
||||||
|
Instance(self.model_classdef.info, [])
|
||||||
|
]))
|
||||||
|
|
||||||
|
|
||||||
def process_model_class(ctx: ClassDefContext,
|
def process_model_class(ctx: ClassDefContext,
|
||||||
django_context: DjangoContext) -> None:
|
django_context: DjangoContext) -> None:
|
||||||
@@ -186,6 +195,7 @@ def process_model_class(ctx: ClassDefContext,
|
|||||||
AddRelatedModelsId,
|
AddRelatedModelsId,
|
||||||
AddManagers,
|
AddManagers,
|
||||||
AddExtraFieldMethods,
|
AddExtraFieldMethods,
|
||||||
|
AddMetaOptionsAttribute
|
||||||
]
|
]
|
||||||
for initializer_cls in initializers:
|
for initializer_cls in initializers:
|
||||||
try:
|
try:
|
||||||
|
|||||||
33
test-data/typecheck/models/test_meta_options.yml
Normal file
33
test-data/typecheck/models/test_meta_options.yml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
- case: meta_attribute_has_a_type_of_current_model
|
||||||
|
main: |
|
||||||
|
from myapp.models import MyUser
|
||||||
|
reveal_type(MyUser._meta) # N: Revealed type is 'django.db.models.options.Options[myapp.models.MyUser]'
|
||||||
|
installed_apps:
|
||||||
|
- myapp
|
||||||
|
files:
|
||||||
|
- path: myapp/__init__.py
|
||||||
|
- path: myapp/models.py
|
||||||
|
content: |
|
||||||
|
from django.db import models
|
||||||
|
class MyUser(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
- case: get_field_returns_proper_field_type
|
||||||
|
main: |
|
||||||
|
from myapp.models import MyUser
|
||||||
|
reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is 'django.db.models.fields.CharField[Any, Any]'
|
||||||
|
reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is 'django.db.models.fields.IntegerField[Any, Any]'
|
||||||
|
reveal_type(MyUser._meta.get_field('unknown'))
|
||||||
|
out: |
|
||||||
|
main:4: note: Revealed type is 'Any'
|
||||||
|
main:4: error: MyUser has no field named 'unknown'
|
||||||
|
installed_apps:
|
||||||
|
- myapp
|
||||||
|
files:
|
||||||
|
- path: myapp/__init__.py
|
||||||
|
- path: myapp/models.py
|
||||||
|
content: |
|
||||||
|
from django.db import models
|
||||||
|
class MyUser(models.Model):
|
||||||
|
name = models.CharField(max_length=100)
|
||||||
|
age = models.IntegerField()
|
||||||
Reference in New Issue
Block a user