From df5c70c703d450249e82a3df938d8b1c1d822128 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Mon, 25 Feb 2019 04:01:36 +0300 Subject: [PATCH] fixes for FormMixin's get_form/get_form_class --- django-stubs/contrib/auth/forms.pyi | 1 - django-stubs/core/files/storage.pyi | 7 +- django-stubs/forms/forms.pyi | 2 +- django-stubs/forms/models.pyi | 40 +++++------ mypy_django_plugin/helpers.py | 22 ++++-- mypy_django_plugin/main.py | 87 ++++++++++++++++++----- mypy_django_plugin/transformers/models.py | 4 +- release.xsh | 2 +- scripts/mypy.ini | 1 - scripts/typecheck_tests.py | 4 ++ test-data/typecheck/forms.test | 24 ++++++- 11 files changed, 139 insertions(+), 55 deletions(-) diff --git a/django-stubs/contrib/auth/forms.pyi b/django-stubs/contrib/auth/forms.pyi index ace768b..041286a 100644 --- a/django-stubs/contrib/auth/forms.pyi +++ b/django-stubs/contrib/auth/forms.pyi @@ -44,7 +44,6 @@ class UserCreationForm(forms.ModelForm): password2: Any = ... def __init__(self, *args: Any, **kwargs: Any) -> None: ... def clean_password2(self) -> str: ... - def save(self, commit: bool = ...) -> User: ... class UserChangeForm(forms.ModelForm): auto_id: str diff --git a/django-stubs/core/files/storage.pyi b/django-stubs/core/files/storage.pyi index 0f27f0a..b4b0242 100644 --- a/django-stubs/core/files/storage.pyi +++ b/django-stubs/core/files/storage.pyi @@ -1,15 +1,12 @@ from datetime import datetime -from io import StringIO, TextIOWrapper -from typing import Any, List, Optional, Tuple, Union +from typing import Any, IO, List, Optional, Tuple from django.core.files.base import File from django.utils.functional import LazyObject class Storage: def open(self, name: str, mode: str = ...) -> File: ... - def save( - self, name: Optional[str], content: Union[StringIO, TextIOWrapper, File], max_length: Optional[int] = ... - ) -> str: ... + def save(self, name: Optional[str], content: IO[Any], max_length: Optional[int] = ...) -> str: ... def get_valid_name(self, name: str) -> str: ... def get_available_name(self, name: str, max_length: Optional[int] = ...) -> str: ... def generate_filename(self, filename: str) -> str: ... diff --git a/django-stubs/forms/forms.pyi b/django-stubs/forms/forms.pyi index f62448d..ede1337 100644 --- a/django-stubs/forms/forms.pyi +++ b/django-stubs/forms/forms.pyi @@ -29,6 +29,7 @@ class BaseForm: empty_permitted: bool = ... fields: Dict[str, Any] = ... renderer: BaseRenderer = ... + cleaned_data: Any = ... def __init__( self, data: Optional[Mapping[str, Any]] = ..., @@ -57,7 +58,6 @@ class BaseForm: def non_field_errors(self) -> ErrorList: ... def add_error(self, field: Optional[str], error: Union[ValidationError, str]) -> None: ... def has_error(self, field: Any, code: Optional[Any] = ...): ... - cleaned_data: Any = ... def full_clean(self) -> None: ... def clean(self) -> Dict[str, Optional[Union[datetime, SimpleUploadedFile, QuerySet, str]]]: ... def has_changed(self) -> bool: ... diff --git a/django-stubs/forms/models.pyi b/django-stubs/forms/models.pyi index 09b42c2..aa46757 100644 --- a/django-stubs/forms/models.pyi +++ b/django-stubs/forms/models.pyi @@ -1,6 +1,6 @@ from collections import OrderedDict from datetime import date, datetime -from typing import Any, Callable, Dict, Iterator, List, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterator, List, MutableMapping, Optional, Sequence, Tuple, Type, Union, Mapping from unittest.mock import MagicMock from uuid import UUID @@ -20,19 +20,17 @@ from typing_extensions import Literal ALL_FIELDS: str -_Fields = Union[List[Union[Callable, str]], Tuple[str]] +_Fields = Union[List[Union[Callable, str]], Sequence[str], Literal["__all__"]] _Labels = Dict[str, str] _ErrorMessages = Dict[str, Dict[str, str]] def model_to_dict( - instance: Model, - fields: Optional[_Fields] = ..., - exclude: Optional[Union[List[Union[Callable, str]], Tuple[str]]] = ..., -) -> Dict[str, Optional[Union[bool, date, float]]]: ... + instance: Model, fields: Optional[_Fields] = ..., exclude: Optional[_Fields] = ... +) -> Dict[str, Any]: ... def fields_for_model( model: Type[Model], fields: Optional[_Fields] = ..., - exclude: Optional[Union[List[Union[Callable, str]], Tuple]] = ..., + exclude: Optional[_Fields] = ..., widgets: Optional[Union[Dict[str, Type[Input]], Dict[str, Widget]]] = ..., formfield_callback: Optional[Union[Callable, str]] = ..., localized_fields: Optional[Union[Tuple[str], str]] = ..., @@ -42,12 +40,12 @@ def fields_for_model( field_classes: Optional[Dict[str, Type[CharField]]] = ..., *, apply_limit_choices_to: bool = ... -) -> OrderedDict: ... +) -> Dict[str, Any]: ... class ModelFormOptions: model: Optional[Type[Model]] = ... fields: Optional[_Fields] = ... - exclude: Optional[Union[List[Union[Callable, str]], Tuple, str]] = ... + exclude: Optional[_Fields] = ... widgets: Optional[Dict[str, Union[Widget, Input]]] = ... localized_fields: Optional[Union[Tuple[str], str]] = ... labels: Optional[_Labels] = ... @@ -57,14 +55,14 @@ class ModelFormOptions: def __init__(self, options: Optional[type] = ...) -> None: ... class ModelFormMetaclass(DeclarativeFieldsMetaclass): - def __new__(mcs, name: str, bases: Sequence[Type[ModelForm]], attrs: Dict[str, Any]) -> Type[ModelForm]: ... + def __new__(mcs, name: str, bases: Sequence[Type[Any]], attrs: Dict[str, Any]) -> Type[ModelForm]: ... class BaseModelForm(BaseForm): instance: Any = ... def __init__( self, - data: Optional[Dict[str, Any]] = ..., - files: Optional[Dict[str, File]] = ..., + data: Optional[Mapping[str, Any]] = ..., + files: Optional[Mapping[str, File]] = ..., auto_id: Union[bool, str] = ..., prefix: Optional[str] = ..., initial: Optional[Dict[str, Any]] = ..., @@ -72,21 +70,21 @@ class BaseModelForm(BaseForm): label_suffix: Optional[str] = ..., empty_permitted: bool = ..., instance: Optional[Model] = ..., - use_required_attribute: None = ..., + use_required_attribute: Optional[bool] = ..., renderer: Any = ..., ) -> None: ... def clean(self) -> Dict[str, Any]: ... def validate_unique(self) -> None: ... save_m2m: Any = ... - def save(self, commit: bool = ...) -> Model: ... + def save(self, commit: bool = ...) -> Any: ... class ModelForm(BaseModelForm): ... def modelform_factory( model: Type[Model], form: Type[ModelForm] = ..., - fields: Optional[Union[Sequence[str], Literal["__all__"]]] = ..., - exclude: Optional[Sequence[str]] = ..., + fields: Optional[_Fields] = ..., + exclude: Optional[_Fields] = ..., formfield_callback: Optional[Union[str, Callable[[models.Field], Field]]] = ..., widgets: Optional[MutableMapping[str, Widget]] = ..., localized_fields: Optional[Sequence[str]] = ..., @@ -142,8 +140,8 @@ def modelformset_factory( can_order: bool = ..., min_num: Optional[int] = ..., max_num: Optional[int] = ..., - fields: Optional[Union[str, Sequence[str]]] = ..., - exclude: Optional[Sequence[str]] = ..., + fields: Optional[_Fields] = ..., + exclude: Optional[_Fields] = ..., widgets: Optional[Dict[str, Any]] = ..., validate_max: bool = ..., localized_fields: Optional[Sequence[str]] = ..., @@ -181,8 +179,8 @@ def inlineformset_factory( form: Type[ModelForm] = ..., formset: Type[BaseInlineFormSet] = ..., fk_name: Optional[str] = ..., - fields: Optional[Union[str, Sequence[str]]] = ..., - exclude: Optional[Sequence[str]] = ..., + fields: Optional[_Fields] = ..., + exclude: Optional[_Fields] = ..., extra: int = ..., can_order: bool = ..., can_delete: bool = ..., @@ -239,7 +237,7 @@ class ModelChoiceField(ChoiceField): self, queryset: Optional[Union[Manager, QuerySet]], *, - empty_label: str = ..., + empty_label: Optional[str] = ..., required: bool = ..., widget: Optional[Any] = ..., label: Optional[Any] = ..., diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 93a2482..97acc39 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -4,7 +4,7 @@ from typing import Dict, Optional from mypy.checker import TypeChecker from mypy.nodes import AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \ TypeInfo -from mypy.plugin import FunctionContext +from mypy.plugin import FunctionContext, MethodContext from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' @@ -22,8 +22,13 @@ QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager' MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager' RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager' + +BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm' +FORM_CLASS_FULLNAME = 'django.forms.forms.Form' MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm' +FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin' + MANAGER_CLASSES = { MANAGER_CLASS_FULLNAME, RELATED_MANAGER_CLASS_FULLNAME, @@ -125,7 +130,7 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance: return Instance(type_to_fill.type, typevar_values) -def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: +def get_argument_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: """Return the expression for the specific argument. This helper should only be used with non-star arguments. @@ -140,7 +145,7 @@ def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression return args[0] -def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]: +def get_argument_type_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Type]: """Return the type for the specific argument. This helper should only be used with non-star arguments. @@ -177,8 +182,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression return None -def iter_over_assignments( - class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]: +def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile] + ) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]: if isinstance(class_or_module, ClassDef): statements = class_or_module.defs.body else: @@ -281,3 +286,10 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo] if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): return metaclass_sym.node return None + + +def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Expression]: + for lvalue, rvalue in iter_over_assignments(type_info.defn): + if isinstance(lvalue, NameExpr) and lvalue.name == name: + return rvalue + return None diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 8942cf4..78eeb31 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -2,10 +2,10 @@ import os from typing import Callable, Dict, Optional, Union, cast from mypy.checker import TypeChecker -from mypy.nodes import MemberExpr, TypeInfo +from mypy.nodes import MemberExpr, TypeInfo, NameExpr from mypy.options import Options from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin -from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType +from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType, CallableType, NoneTyp from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.config import Config @@ -35,10 +35,10 @@ def transform_manager_class(ctx: ClassDefContext) -> None: sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1 -def transform_modelform_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(helpers.MODELFORM_CLASS_FULLNAME) +def transform_form_class(ctx: ClassDefContext) -> None: + sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME) if sym is not None and isinstance(sym.node, TypeInfo): - sym.node.metadata['django']['modelform_bases'][ctx.cls.fullname] = 1 + sym.node.metadata['django']['baseform_bases'][ctx.cls.fullname] = 1 make_meta_nested_class_inherit_from_any(ctx) @@ -158,6 +158,47 @@ class ExtractSettingType: return ctx.default_attr_type +def transform_form_view(ctx: ClassDefContext) -> None: + form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class') + if isinstance(form_class_value, NameExpr): + helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname + + +def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + if not form_class_fullname: + return ctx.default_return_type + + return TypeType(ctx.api.named_generic_type(form_class_fullname, [])) + + +def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class') + if form_class_type is None or isinstance(form_class_type, NoneTyp): + # extract from specified form_class in metadata + form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + if not form_class_fullname: + return ctx.default_return_type + + return ctx.api.named_generic_type(form_class_fullname, []) + + if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): + return form_class_type.item + + if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): + return form_class_type.ret_type + + return ctx.default_return_type + + class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -186,8 +227,7 @@ class DjangoPlugin(Plugin): def _get_current_model_bases(self) -> Dict[str, int]: model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (model_sym.node.metadata - .setdefault('django', {}) + return (helpers.get_django_metadata(model_sym.node) .setdefault('model_bases', {helpers.MODEL_CLASS_FULLNAME: 1})) else: return {} @@ -195,18 +235,18 @@ class DjangoPlugin(Plugin): def _get_current_manager_bases(self) -> Dict[str, int]: model_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (model_sym.node.metadata - .setdefault('django', {}) + return (helpers.get_django_metadata(model_sym.node) .setdefault('manager_bases', {helpers.MANAGER_CLASS_FULLNAME: 1})) else: return {} - def _get_current_modelform_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(helpers.MODELFORM_CLASS_FULLNAME) + def _get_current_form_bases(self) -> Dict[str, int]: + model_sym = self.lookup_fully_qualified(helpers.BASEFORM_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (model_sym.node.metadata - .setdefault('django', {}) - .setdefault('modelform_bases', {helpers.MODELFORM_CLASS_FULLNAME: 1})) + return (helpers.get_django_metadata(model_sym.node) + .setdefault('baseform_bases', {helpers.BASEFORM_CLASS_FULLNAME: 1, + helpers.FORM_CLASS_FULLNAME: 1, + helpers.MODELFORM_CLASS_FULLNAME: 1})) else: return {} @@ -229,6 +269,17 @@ class DjangoPlugin(Plugin): def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: + class_name, _, method_name = fullname.rpartition('.') + if method_name == 'get_form_class': + sym = self.lookup_fully_qualified(class_name) + if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + return extract_proper_type_for_get_form_class + + if method_name == 'get_form': + sym = self.lookup_fully_qualified(class_name) + if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + return extract_proper_type_for_get_form + if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: return determine_model_cls_from_string_for_migrations @@ -254,8 +305,12 @@ class DjangoPlugin(Plugin): if fullname in self._get_current_manager_bases(): return transform_manager_class - if fullname in self._get_current_modelform_bases(): - return transform_modelform_class + if fullname in self._get_current_form_bases(): + return transform_form_class + + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + return transform_form_view return None diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index ffebd07..83da0c5 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -26,9 +26,7 @@ class ModelClassInitializer(metaclass=ABCMeta): if meta_node is None: return None - for lvalue, rvalue in iter_over_assignments(meta_node.defn): - if isinstance(lvalue, NameExpr) and lvalue.name == name: - return rvalue + return helpers.get_assigned_value_for_class(meta_node, name) def is_abstract_model(self) -> bool: is_abstract_expr = self.get_meta_attribute('abstract') diff --git a/release.xsh b/release.xsh index 50d0c0d..16043c4 100755 --- a/release.xsh +++ b/release.xsh @@ -1,7 +1,7 @@ #!/usr/local/bin/xonsh try: - pip install wheel + pip install wheel twine python setup.py sdist bdist_wheel --universal twine upload dist/* diff --git a/scripts/mypy.ini b/scripts/mypy.ini index 0933a35..e23a468 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -4,7 +4,6 @@ ignore_missing_imports = True check_untyped_defs = True warn_no_return = False show_traceback = True -warn_redundant_casts = True allow_redefinition = True incremental = False diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index 112a450..9d06044 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -228,6 +228,10 @@ IGNORED_ERRORS = { '"object" has no attribute "items"', '"Field" has no attribute "many_to_many"' ], + 'model_forms': [ + 'Argument "instance" to "InvalidModelForm" has incompatible type "Type[Category]"; expected "Optional[Model]"', + 'Invalid type "NewForm"' + ], 'model_fields': [ 'Incompatible types in assignment (expression has type "Type[Person]", variable has type', 'Unexpected keyword argument "name" for "Person"', diff --git a/test-data/typecheck/forms.test b/test-data/typecheck/forms.test index c00239e..649289e 100644 --- a/test-data/typecheck/forms.test +++ b/test-data/typecheck/forms.test @@ -16,4 +16,26 @@ class CategoryForm(forms.ModelForm): fields = '__all__' class CompositeForm(ArticleForm, CategoryForm): pass -[out] \ No newline at end of file +[/CASE] + +[CASE formview_methods_on_forms_return_proper_types] +from typing import Any +from django import forms +from django.views.generic.edit import FormView + +class MyForm(forms.ModelForm): + pass +class MyForm2(forms.ModelForm): + pass + +class MyView(FormView): + form_class = MyForm + + def post(self, request, *args: Any, **kwds: Any): + form_class = self.get_form_class() + reveal_type(form_class) # E: Revealed type is 'Type[main.MyForm]' + reveal_type(self.get_form(None)) # E: Revealed type is 'main.MyForm' + reveal_type(self.get_form()) # E: Revealed type is 'main.MyForm' + reveal_type(self.get_form(form_class)) # E: Revealed type is 'main.MyForm' + reveal_type(self.get_form(MyForm2)) # E: Revealed type is 'main.MyForm2' +[/CASE]