mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 12:14:28 +08:00
add support for _meta.get_field() typechecking
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import date
|
||||
from io import BufferedReader, StringIO, TextIOWrapper
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.management.base import OutputWrapper
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey, ManyToManyField
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from django.db.models.fields import Field
|
||||
|
||||
class SerializerDoesNotExist(KeyError): ...
|
||||
class SerializationError(Exception): ...
|
||||
@@ -43,7 +43,7 @@ class Serializer:
|
||||
first: bool = ...
|
||||
def serialize(
|
||||
self,
|
||||
queryset: Union[Iterator[Any], List[Model], QuerySet],
|
||||
queryset: Iterable[Model],
|
||||
*,
|
||||
stream: Optional[Any] = ...,
|
||||
fields: Optional[Any] = ...,
|
||||
@@ -52,7 +52,7 @@ class Serializer:
|
||||
progress_output: Optional[Any] = ...,
|
||||
object_count: int = ...,
|
||||
**options: Any
|
||||
) -> Optional[Union[List[OrderedDict], bytes, str]]: ...
|
||||
) -> Any: ...
|
||||
def start_serialization(self) -> None: ...
|
||||
def end_serialization(self) -> None: ...
|
||||
def start_object(self, obj: Any) -> None: ...
|
||||
@@ -72,13 +72,16 @@ class Deserializer:
|
||||
class DeserializedObject:
|
||||
object: Any = ...
|
||||
m2m_data: Dict[str, List[int]] = ...
|
||||
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
|
||||
deferred_fields: Mapping[Field, Any]
|
||||
def __init__(
|
||||
self,
|
||||
obj: Model,
|
||||
m2m_data: Optional[Dict[str, List[int]]] = ...,
|
||||
deferred_fields: Optional[Mapping[Field, Any]] = ...,
|
||||
) -> None: ...
|
||||
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
|
||||
def save_deferred_fields(self, using: Optional[str] = ...) -> None: ...
|
||||
|
||||
def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ...
|
||||
def deserialize_m2m_values(
|
||||
field: ManyToManyField, field_value: Union[List[List[str]], List[int]], using: str
|
||||
) -> List[int]: ...
|
||||
def deserialize_fk_value(
|
||||
field: ForeignKey, field_value: Optional[Union[List[str], Tuple[str], int, str]], using: str
|
||||
) -> Optional[Union[int, str, UUID]]: ...
|
||||
def deserialize_m2m_values(field: ManyToManyField, field_value: Any, using: str) -> List[Any]: ...
|
||||
def deserialize_fk_value(field: ForeignKey, field_value: Any, using: str) -> Any: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar, Type
|
||||
|
||||
from django.db.models.manager import Manager
|
||||
|
||||
@@ -6,16 +6,16 @@ from django.core.checks.messages import CheckMessage
|
||||
|
||||
from django.db.models.options import Options
|
||||
|
||||
class ModelBase(type): ...
|
||||
|
||||
_Self = TypeVar("_Self", bound="Model")
|
||||
|
||||
class ModelBase(type): ...
|
||||
|
||||
class Model(metaclass=ModelBase):
|
||||
class DoesNotExist(Exception): ...
|
||||
class MultipleObjectsReturned(Exception): ...
|
||||
class Meta: ...
|
||||
_meta: Options
|
||||
_default_manager: Manager[Model]
|
||||
_meta: Options[Any]
|
||||
pk: Any = ...
|
||||
def __init__(self: _Self, *args, **kwargs) -> None: ...
|
||||
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
|
||||
|
||||
@@ -39,12 +39,14 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
max_length: Optional[int]
|
||||
model: Type[Model]
|
||||
name: str
|
||||
verbose_name: str
|
||||
blank: bool = ...
|
||||
null: bool = ...
|
||||
editable: bool = ...
|
||||
choices: Optional[_FieldChoices] = ...
|
||||
db_column: Optional[str]
|
||||
column: str
|
||||
error_messages: _ErrorMessagesToOverride
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name: Optional[Union[str, bytes]] = ...,
|
||||
@@ -78,6 +80,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
def get_prep_value(self, value: Any) -> Any: ...
|
||||
def get_internal_type(self) -> str: ...
|
||||
def formfield(self, **kwargs) -> FormField: ...
|
||||
def save_form_data(self, instance: Model, data: Any) -> None: ...
|
||||
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
|
||||
def to_python(self, value: Any) -> Any: ...
|
||||
def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ...
|
||||
@@ -91,6 +94,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
def has_default(self) -> bool: ...
|
||||
def get_default(self) -> Any: ...
|
||||
def check(self, **kwargs: Any) -> List[checks.Error]: ...
|
||||
@property
|
||||
def validators(self) -> List[_ValidatorCallable]: ...
|
||||
|
||||
class IntegerField(Field[_ST, _GT]):
|
||||
_pyi_private_set_type: Union[float, int, str, Combinable]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, TypeVar, Generic
|
||||
|
||||
from django.apps.config import AppConfig
|
||||
from django.apps.registry import Apps
|
||||
@@ -29,7 +29,9 @@ def make_immutable_fields_list(
|
||||
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
|
||||
) -> ImmutableList: ...
|
||||
|
||||
class Options:
|
||||
_M = TypeVar('_M', bound=Model)
|
||||
|
||||
class Options(Generic[_M]):
|
||||
base_manager: Manager
|
||||
concrete_fields: ImmutableList
|
||||
default_manager: Manager
|
||||
|
||||
@@ -78,7 +78,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
|
||||
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
|
||||
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
|
||||
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
|
||||
def create(self, **kwargs: Any) -> _T: ...
|
||||
def create(self, *args: Any, **kwargs: Any) -> _T: ...
|
||||
def bulk_create(
|
||||
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
|
||||
) -> List[_T]: ...
|
||||
|
||||
@@ -35,3 +35,4 @@ RELATED_FIELDS_CLASSES = {
|
||||
}
|
||||
|
||||
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
||||
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'
|
||||
|
||||
@@ -11,7 +11,7 @@ from mypy.types import Type as MypyType
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings
|
||||
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings, meta
|
||||
from mypy_django_plugin.transformers.models import process_model_class
|
||||
|
||||
|
||||
@@ -168,15 +168,20 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
return forms.extract_proper_type_for_get_form
|
||||
|
||||
if method_name == 'values':
|
||||
model_info = self._get_typeinfo_or_none(class_fullname)
|
||||
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
|
||||
|
||||
if method_name == 'values_list':
|
||||
model_info = self._get_typeinfo_or_none(class_fullname)
|
||||
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
|
||||
|
||||
if method_name == 'get_field':
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
|
||||
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
|
||||
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
|
||||
38
mypy_django_plugin/transformers/meta.py
Normal file
38
mypy_django_plugin/transformers/meta.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import AnyType, Type as MypyType, TypeOfAny, Instance
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||
field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname)
|
||||
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
|
||||
|
||||
|
||||
def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
model_type = ctx.type.args[0]
|
||||
if not isinstance(model_type, Instance):
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
|
||||
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||
if field_name_expr is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
|
||||
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
|
||||
if field_name is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
|
||||
try:
|
||||
field = model_cls._meta.get_field(field_name)
|
||||
except FieldDoesNotExist as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_fullname = helpers.get_class_fullname(field.__class__)
|
||||
return _get_field_instance(ctx, field_fullname)
|
||||
@@ -2,15 +2,15 @@ from abc import ABCMeta, abstractmethod
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import cast
|
||||
|
||||
from django.db.models.fields import DateTimeField, DateField
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
|
||||
from mypy.newsemanal.semanal import NewSemanticAnalyzer
|
||||
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2
|
||||
from mypy.nodes import ARG_STAR2, Argument, MDEF, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins import common
|
||||
from mypy.types import Instance, TypeOfAny, AnyType
|
||||
from mypy.types import AnyType, Instance, TypeOfAny
|
||||
|
||||
from django.db.models.fields import DateField, DateTimeField
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.transformers import fields
|
||||
@@ -177,6 +177,15 @@ class AddExtraFieldMethods(ModelClassInitializer):
|
||||
return_type=return_type)
|
||||
|
||||
|
||||
class AddMetaOptionsAttribute(ModelClassInitializer):
|
||||
def run(self):
|
||||
if '_meta' not in self.model_classdef.info.names:
|
||||
options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME)
|
||||
self.add_new_node_to_model_class('_meta',
|
||||
Instance(options_info, [
|
||||
Instance(self.model_classdef.info, [])
|
||||
]))
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext,
|
||||
django_context: DjangoContext) -> None:
|
||||
@@ -186,6 +195,7 @@ def process_model_class(ctx: ClassDefContext,
|
||||
AddRelatedModelsId,
|
||||
AddManagers,
|
||||
AddExtraFieldMethods,
|
||||
AddMetaOptionsAttribute
|
||||
]
|
||||
for initializer_cls in initializers:
|
||||
try:
|
||||
|
||||
33
test-data/typecheck/models/test_meta_options.yml
Normal file
33
test-data/typecheck/models/test_meta_options.yml
Normal file
@@ -0,0 +1,33 @@
|
||||
- case: meta_attribute_has_a_type_of_current_model
|
||||
main: |
|
||||
from myapp.models import MyUser
|
||||
reveal_type(MyUser._meta) # N: Revealed type is 'django.db.models.options.Options[myapp.models.MyUser]'
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class MyUser(models.Model):
|
||||
pass
|
||||
|
||||
- case: get_field_returns_proper_field_type
|
||||
main: |
|
||||
from myapp.models import MyUser
|
||||
reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is 'django.db.models.fields.CharField[Any, Any]'
|
||||
reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is 'django.db.models.fields.IntegerField[Any, Any]'
|
||||
reveal_type(MyUser._meta.get_field('unknown'))
|
||||
out: |
|
||||
main:4: note: Revealed type is 'Any'
|
||||
main:4: error: MyUser has no field named 'unknown'
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class MyUser(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
age = models.IntegerField()
|
||||
Reference in New Issue
Block a user