add support for _meta.get_field() typechecking

This commit is contained in:
Maxim Kurnikov
2019-07-19 16:27:13 +03:00
parent 5bb1bc250d
commit fc9843bea6
10 changed files with 124 additions and 27 deletions

View File

@@ -1,13 +1,13 @@
from collections import OrderedDict
from datetime import date
from io import BufferedReader, StringIO, TextIOWrapper
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Type, Union
from uuid import UUID
from django.core.management.base import OutputWrapper
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey, ManyToManyField
from django.db.models.query import QuerySet
from django.db.models.fields import Field
class SerializerDoesNotExist(KeyError): ...
class SerializationError(Exception): ...
@@ -43,7 +43,7 @@ class Serializer:
first: bool = ...
def serialize(
self,
queryset: Union[Iterator[Any], List[Model], QuerySet],
queryset: Iterable[Model],
*,
stream: Optional[Any] = ...,
fields: Optional[Any] = ...,
@@ -52,7 +52,7 @@ class Serializer:
progress_output: Optional[Any] = ...,
object_count: int = ...,
**options: Any
) -> Optional[Union[List[OrderedDict], bytes, str]]: ...
) -> Any: ...
def start_serialization(self) -> None: ...
def end_serialization(self) -> None: ...
def start_object(self, obj: Any) -> None: ...
@@ -72,13 +72,16 @@ class Deserializer:
class DeserializedObject:
object: Any = ...
m2m_data: Dict[str, List[int]] = ...
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
deferred_fields: Mapping[Field, Any]
def __init__(
self,
obj: Model,
m2m_data: Optional[Dict[str, List[int]]] = ...,
deferred_fields: Optional[Mapping[Field, Any]] = ...,
) -> None: ...
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
def save_deferred_fields(self, using: Optional[str] = ...) -> None: ...
def build_instance(Model: Type[Model], data: Dict[str, Optional[Union[date, int, str, UUID]]], db: str) -> Model: ...
def deserialize_m2m_values(
field: ManyToManyField, field_value: Union[List[List[str]], List[int]], using: str
) -> List[int]: ...
def deserialize_fk_value(
field: ForeignKey, field_value: Optional[Union[List[str], Tuple[str], int, str]], using: str
) -> Optional[Union[int, str, UUID]]: ...
def deserialize_m2m_values(field: ManyToManyField, field_value: Any, using: str) -> List[Any]: ...
def deserialize_fk_value(field: ForeignKey, field_value: Any, using: str) -> Any: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, ClassVar, Type
from django.db.models.manager import Manager
@@ -6,16 +6,16 @@ from django.core.checks.messages import CheckMessage
from django.db.models.options import Options
class ModelBase(type): ...
_Self = TypeVar("_Self", bound="Model")
class ModelBase(type): ...
class Model(metaclass=ModelBase):
class DoesNotExist(Exception): ...
class MultipleObjectsReturned(Exception): ...
class Meta: ...
_meta: Options
_default_manager: Manager[Model]
_meta: Options[Any]
pk: Any = ...
def __init__(self: _Self, *args, **kwargs) -> None: ...
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...

View File

@@ -39,12 +39,14 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
max_length: Optional[int]
model: Type[Model]
name: str
verbose_name: str
blank: bool = ...
null: bool = ...
editable: bool = ...
choices: Optional[_FieldChoices] = ...
db_column: Optional[str]
column: str
error_messages: _ErrorMessagesToOverride
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -78,6 +80,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def get_prep_value(self, value: Any) -> Any: ...
def get_internal_type(self) -> str: ...
def formfield(self, **kwargs) -> FormField: ...
def save_form_data(self, instance: Model, data: Any) -> None: ...
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
def to_python(self, value: Any) -> Any: ...
def clean(self, value: Any, model_instance: Optional[Model]) -> Any: ...
@@ -91,6 +94,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def has_default(self) -> bool: ...
def get_default(self) -> Any: ...
def check(self, **kwargs: Any) -> List[checks.Error]: ...
@property
def validators(self) -> List[_ValidatorCallable]: ...
class IntegerField(Field[_ST, _GT]):
_pyi_private_set_type: Union[float, int, str, Combinable]

View File

@@ -1,5 +1,5 @@
import collections
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, TypeVar, Generic
from django.apps.config import AppConfig
from django.apps.registry import Apps
@@ -29,7 +29,9 @@ def make_immutable_fields_list(
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
) -> ImmutableList: ...
class Options:
_M = TypeVar('_M', bound=Model)
class Options(Generic[_M]):
base_manager: Manager
concrete_fields: ImmutableList
default_manager: Manager

View File

@@ -78,7 +78,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ...
def create(self, *args: Any, **kwargs: Any) -> _T: ...
def bulk_create(
self, objs: Iterable[Model], batch_size: Optional[int] = ..., ignore_conflicts: bool = ...
) -> List[_T]: ...

View File

@@ -35,3 +35,4 @@ RELATED_FIELDS_CLASSES = {
}
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'

View File

@@ -11,7 +11,7 @@ from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings, meta
from mypy_django_plugin.transformers.models import process_model_class
@@ -168,15 +168,20 @@ class NewSemanalDjangoPlugin(Plugin):
return forms.extract_proper_type_for_get_form
if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
if method_name == 'get_field':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)

View File

@@ -0,0 +1,38 @@
from django.core.exceptions import FieldDoesNotExist
from mypy.plugin import MethodContext
from mypy.types import AnyType, Type as MypyType, TypeOfAny, Instance
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname)
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
model_type = ctx.type.args[0]
if not isinstance(model_type, Instance):
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
if field_name_expr is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
if field_name is None:
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
try:
field = model_cls._meta.get_field(field_name)
except FieldDoesNotExist as exc:
ctx.api.fail(exc.args[0], ctx.context)
return AnyType(TypeOfAny.from_error)
field_fullname = helpers.get_class_fullname(field.__class__)
return _get_field_instance(ctx, field_fullname)

View File

@@ -2,15 +2,15 @@ from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod
from typing import cast
from django.db.models.fields import DateTimeField, DateField
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2
from mypy.nodes import ARG_STAR2, Argument, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.plugins import common
from mypy.types import Instance, TypeOfAny, AnyType
from mypy.types import AnyType, Instance, TypeOfAny
from django.db.models.fields import DateField, DateTimeField
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields
@@ -177,6 +177,15 @@ class AddExtraFieldMethods(ModelClassInitializer):
return_type=return_type)
class AddMetaOptionsAttribute(ModelClassInitializer):
def run(self):
if '_meta' not in self.model_classdef.info.names:
options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME)
self.add_new_node_to_model_class('_meta',
Instance(options_info, [
Instance(self.model_classdef.info, [])
]))
def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None:
@@ -186,6 +195,7 @@ def process_model_class(ctx: ClassDefContext,
AddRelatedModelsId,
AddManagers,
AddExtraFieldMethods,
AddMetaOptionsAttribute
]
for initializer_cls in initializers:
try:

View File

@@ -0,0 +1,33 @@
- case: meta_attribute_has_a_type_of_current_model
main: |
from myapp.models import MyUser
reveal_type(MyUser._meta) # N: Revealed type is 'django.db.models.options.Options[myapp.models.MyUser]'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyUser(models.Model):
pass
- case: get_field_returns_proper_field_type
main: |
from myapp.models import MyUser
reveal_type(MyUser._meta.get_field('name')) # N: Revealed type is 'django.db.models.fields.CharField[Any, Any]'
reveal_type(MyUser._meta.get_field('age')) # N: Revealed type is 'django.db.models.fields.IntegerField[Any, Any]'
reveal_type(MyUser._meta.get_field('unknown'))
out: |
main:4: note: Revealed type is 'Any'
main:4: error: MyUser has no field named 'unknown'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyUser(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()