add support for _meta.get_field() typechecking

This commit is contained in:
Maxim Kurnikov
2019-07-19 16:27:13 +03:00
parent 5bb1bc250d
commit fc9843bea6
10 changed files with 124 additions and 27 deletions

View File

@@ -1,13 +1,13 @@
from collections import OrderedDict
from datetime import date from datetime import date
from io import BufferedReader, StringIO, TextIOWrapper from io import BufferedReader, StringIO, TextIOWrapper
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union
from uuid import UUID from uuid import UUID
from django.core.management.base import OutputWrapper from django.core.management.base import OutputWrapper
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey, ManyToManyField from django.db.models.fields.related import ForeignKey, ManyToManyField
from django.db.models.query import QuerySet
from django.db.models.fields import Field
class SerializerDoesNotExist(KeyError): ... class SerializerDoesNotExist(KeyError): ...
class SerializationError(Exception): ... class SerializationError(Exception): ...
@@ -43,7 +43,7 @@ class Serializer:
first: bool = ... first: bool = ...
def serialize( def serialize(
self, self,
queryset: Union[Iterator[Any], List[Model], QuerySet], queryset: Iterable[Model],
*, *,
stream: Optional[Any] = ..., stream: Optional[Any] = ...,
fields: Optional[Any] = ..., fields: Optional[Any] = ...,
@@ -52,7 +52,7 @@ class Serializer:
progress_output: Optional[Any] = ..., progress_output: Optional[Any] = ...,
object_count: int = ..., object_count: int = ...,
**options: Any **options: Any
) -> Optional[Union[List[OrderedDict], bytes, str]]: ... ) -> Any: ...
def start_serialization(self) -> None: ... def start_serialization(self) -> None: ...
def end_serialization(self) -> None: ... def end_serialization(self) -> None: ...
def start_object(self, obj: Any) -> None: ... def start_object(self, obj: Any) -> None: ...
@@ -72,13 +72,16 @@ class Deserializer:
class DeserializedObject: class DeserializedObject:
object: Any = ... object: Any = ...
m2m_data: Dict[str, List[int]] = ... m2m_data: Dict[str, List[int]] = ...
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ... deferred_fields: Mapping[Field, Any]
def __init__(
self,
obj: Model,
m2m_data: Optional[Dict[str, List[int]]] = ...,
deferred_fields: Optional[Mapping[Field, Any]] = ...,
) -> None: ...
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ... def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
def save_deferred_fields(self, using: Optional[str] = ...) -> None: ...
def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ... def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ...
def deserialize_m2m_values( def deserialize_m2m_values(field: ManyToManyField, field_value: Any, using: str) -> List[Any]: ...
field: ManyToManyField, field_value: Union[List[List[str]], List[int]], using: str def deserialize_fk_value(field: ForeignKey, field_value: Any, using: str) -> Any: ...
) -> List[int]: ...
def deserialize_fk_value(
field: ForeignKey, field_value: Optional[Union[List[str], Tuple[str], int, str]], using: str
) -> Optional[Union[int, str, UUID]]: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar, Type
from django.db.models.manager import Manager from django.db.models.manager import Manager
@@ -6,16 +6,16 @@ from django.core.checks.messages import CheckMessage
from django.db.models.options import Options from django.db.models.options import Options
class ModelBase(type): ...
_Self = TypeVar("_Self", bound="Model") _Self = TypeVar("_Self", bound="Model")
class ModelBase(type): ...
class Model(metaclass=ModelBase): class Model(metaclass=ModelBase):
class DoesNotExist(Exception): ... class DoesNotExist(Exception): ...
class MultipleObjectsReturned(Exception): ... class MultipleObjectsReturned(Exception): ...
class Meta: ... class Meta: ...
_meta: Options
_default_manager: Manager[Model] _default_manager: Manager[Model]
_meta: Options[Any]
pk: Any = ... pk: Any = ...
def __init__(self: _Self, *args, **kwargs) -> None: ... def __init__(self: _Self, *args, **kwargs) -> None: ...
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ... def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...

View File

@@ -39,12 +39,14 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
max_length: Optional[int] max_length: Optional[int]
model: Type[Model] model: Type[Model]
name: str name: str
verbose_name: str
blank: bool = ... blank: bool = ...
null: bool = ... null: bool = ...
editable: bool = ... editable: bool = ...
choices: Optional[_FieldChoices] = ... choices: Optional[_FieldChoices] = ...
db_column: Optional[str] db_column: Optional[str]
column: str column: str
error_messages: _ErrorMessagesToOverride
def __init__( def __init__(
self, self,
verbose_name: Optional[Union[str, bytes]] = ..., verbose_name: Optional[Union[str, bytes]] = ...,
@@ -78,6 +80,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def get_prep_value(self, value: Any) -> Any: ... def get_prep_value(self, value: Any) -> Any: ...
def get_internal_type(self) -> str: ... def get_internal_type(self) -> str: ...
def formfield(self, **kwargs) -> FormField: ... def formfield(self, **kwargs) -> FormField: ...
def save_form_data(self, instance: Model, data: Any) -> None: ...
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ... def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
def to_python(self, value: Any) -> Any: ... def to_python(self, value: Any) -> Any: ...
def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ... def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ...
@@ -91,6 +94,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def has_default(self) -> bool: ... def has_default(self) -> bool: ...
def get_default(self) -> Any: ... def get_default(self) -> Any: ...
def check(self, **kwargs: Any) -> List[checks.Error]: ... def check(self, **kwargs: Any) -> List[checks.Error]: ...
@property
def validators(self) -> List[_ValidatorCallable]: ...
class IntegerField(Field[_ST, _GT]): class IntegerField(Field[_ST, _GT]):
_pyi_private_set_type: Union[float, int, str, Combinable] _pyi_private_set_type: Union[float, int, str, Combinable]

View File

@@ -1,5 +1,5 @@
import collections import collections
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, TypeVar, Generic
from django.apps.config import AppConfig from django.apps.config import AppConfig
from django.apps.registry import Apps from django.apps.registry import Apps
@@ -29,7 +29,9 @@ def make_immutable_fields_list(
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]] name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
) -> ImmutableList: ... ) -> ImmutableList: ...
class Options: _M = TypeVar('_M', bound=Model)
class Options(Generic[_M]):
base_manager: Manager base_manager: Manager
concrete_fields: ImmutableList concrete_fields: ImmutableList
default_manager: Manager default_manager: Manager

View File

@@ -78,7 +78,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ... def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ... def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ... def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ... def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create( def bulk_create(
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ... self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ... ) -> List[_T]: ...

View File

@@ -35,3 +35,4 @@ RELATED_FIELDS_CLASSES = {
} }
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration' MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'

View File

@@ -11,7 +11,7 @@ from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings, meta
from mypy_django_plugin.transformers.models import process_model_class from mypy_django_plugin.transformers.models import process_model_class
@@ -168,15 +168,20 @@ class NewSemanalDjangoPlugin(Plugin):
return forms.extract_proper_type_for_get_form return forms.extract_proper_type_for_get_form
if method_name == 'values': if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context) return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == 'values_list': if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_fullname) info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
if method_name == 'get_field':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
manager_classes = self._get_current_manager_bases() manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create': if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)

View File

@@ -0,0 +1,38 @@
from django.core.exceptions import FieldDoesNotExist
from mypy.plugin import MethodContext
from mypy.types import AnyType, Type as MypyType, TypeOfAny, Instance
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname)
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
model_type = ctx.type.args[0]
if not isinstance(model_type, Instance):
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
if field_name_expr is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
if field_name is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
try:
field = model_cls._meta.get_field(field_name)
except FieldDoesNotExist as exc:
ctx.api.fail(exc.args[0], ctx.context)
return AnyType(TypeOfAny.from_error)
field_fullname = helpers.get_class_fullname(field.__class__)
return _get_field_instance(ctx, field_fullname)

View File

@@ -2,15 +2,15 @@ from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import cast from typing import cast
from django.db.models.fields import DateTimeField, DateField
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.newsemanal.semanal import NewSemanticAnalyzer from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2 from mypy.nodes import ARG_STAR2, Argument, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins import common from mypy.plugins import common
from mypy.types import Instance, TypeOfAny, AnyType from mypy.types import AnyType, Instance, TypeOfAny
from django.db.models.fields import DateField, DateTimeField
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers import fields
@@ -177,6 +177,15 @@ class AddExtraFieldMethods(ModelClassInitializer):
return_type=return_type) return_type=return_type)
class AddMetaOptionsAttribute(ModelClassInitializer):
def run(self):
if '_meta' not in self.model_classdef.info.names:
options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME)
self.add_new_node_to_model_class('_meta',
Instance(options_info, [
Instance(self.model_classdef.info, [])
]))
def process_model_class(ctx: ClassDefContext, def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None: django_context: DjangoContext) -> None:
@@ -186,6 +195,7 @@ def process_model_class(ctx: ClassDefContext,
AddRelatedModelsId, AddRelatedModelsId,
AddManagers, AddManagers,
AddExtraFieldMethods, AddExtraFieldMethods,
AddMetaOptionsAttribute
] ]
for initializer_cls in initializers: for initializer_cls in initializers:
try: try:

View File

@@ -0,0 +1,33 @@
- case: meta_attribute_has_a_type_of_current_model
main: |
from myapp.models import MyUser
reveal_type(MyUser._meta) # N: Revealed type is 'django.db.models.options.Options[myapp.models.MyUser]'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyUser(models.Model):
pass
- case: get_field_returns_proper_field_type
main: |
from myapp.models import MyUser
reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is 'django.db.models.fields.CharField[Any, Any]'
reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is 'django.db.models.fields.IntegerField[Any, Any]'
reveal_type(MyUser._meta.get_field('unknown'))
out: |
main:4: note: Revealed type is 'Any'
main:4: error: MyUser has no field named 'unknown'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyUser(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()