fixes for FormMixin's get_form/get_form_class

This commit is contained in:
Maxim Kurnikov
2019-02-25 04:01:36 +03:00
parent c09a97e005
commit df5c70c703
11 changed files with 139 additions and 55 deletions

View File

@@ -44,7 +44,6 @@ class UserCreationForm(forms.ModelForm):
password2: Any = ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def clean_password2(self) -> str: ...
def save(self, commit: bool = ...) -> User: ...
class UserChangeForm(forms.ModelForm):
auto_id: str

View File

@@ -1,15 +1,12 @@
from datetime import datetime
from io import StringIO, TextIOWrapper
from typing import Any, List, Optional, Tuple, Union
from typing import Any, IO, List, Optional, Tuple
from django.core.files.base import File
from django.utils.functional import LazyObject
class Storage:
def open(self, name: str, mode: str = ...) -> File: ...
def save(
self, name: Optional[str], content: Union[StringIO, TextIOWrapper, File], max_length: Optional[int] = ...
) -> str: ...
def save(self, name: Optional[str], content: IO[Any], max_length: Optional[int] = ...) -> str: ...
def get_valid_name(self, name: str) -> str: ...
def get_available_name(self, name: str, max_length: Optional[int] = ...) -> str: ...
def generate_filename(self, filename: str) -> str: ...

View File

@@ -29,6 +29,7 @@ class BaseForm:
empty_permitted: bool = ...
fields: Dict[str, Any] = ...
renderer: BaseRenderer = ...
cleaned_data: Any = ...
def __init__(
self,
data: Optional[Mapping[str, Any]] = ...,
@@ -57,7 +58,6 @@ class BaseForm:
def non_field_errors(self) -> ErrorList: ...
def add_error(self, field: Optional[str], error: Union[ValidationError, str]) -> None: ...
def has_error(self, field: Any, code: Optional[Any] = ...): ...
cleaned_data: Any = ...
def full_clean(self) -> None: ...
def clean(self) -> Dict[str, Optional[Union[datetime, SimpleUploadedFile, QuerySet, str]]]: ...
def has_changed(self) -> bool: ...

View File

@@ -1,6 +1,6 @@
from collections import OrderedDict
from datetime import date, datetime
from typing import Any, Callable, Dict, Iterator, List, MutableMapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterator, List, MutableMapping, Optional, Sequence, Tuple, Type, Union, Mapping
from unittest.mock import MagicMock
from uuid import UUID
@@ -20,19 +20,17 @@ from typing_extensions import Literal
ALL_FIELDS: str
_Fields = Union[List[Union[Callable, str]], Tuple[str]]
_Fields = Union[List[Union[Callable, str]], Sequence[str], Literal["__all__"]]
_Labels = Dict[str, str]
_ErrorMessages = Dict[str, Dict[str, str]]
def model_to_dict(
instance: Model,
fields: Optional[_Fields] = ...,
exclude: Optional[Union[List[Union[Callable, str]], Tuple[str]]] = ...,
) -> Dict[str, Optional[Union[bool, date, float]]]: ...
instance: Model, fields: Optional[_Fields] = ..., exclude: Optional[_Fields] = ...
) -> Dict[str, Any]: ...
def fields_for_model(
model: Type[Model],
fields: Optional[_Fields] = ...,
exclude: Optional[Union[List[Union[Callable, str]], Tuple]] = ...,
exclude: Optional[_Fields] = ...,
widgets: Optional[Union[Dict[str, Type[Input]], Dict[str, Widget]]] = ...,
formfield_callback: Optional[Union[Callable, str]] = ...,
localized_fields: Optional[Union[Tuple[str], str]] = ...,
@@ -42,12 +40,12 @@ def fields_for_model(
field_classes: Optional[Dict[str, Type[CharField]]] = ...,
*,
apply_limit_choices_to: bool = ...
) -> OrderedDict: ...
) -> Dict[str, Any]: ...
class ModelFormOptions:
model: Optional[Type[Model]] = ...
fields: Optional[_Fields] = ...
exclude: Optional[Union[List[Union[Callable, str]], Tuple, str]] = ...
exclude: Optional[_Fields] = ...
widgets: Optional[Dict[str, Union[Widget, Input]]] = ...
localized_fields: Optional[Union[Tuple[str], str]] = ...
labels: Optional[_Labels] = ...
@@ -57,14 +55,14 @@ class ModelFormOptions:
def __init__(self, options: Optional[type] = ...) -> None: ...
class ModelFormMetaclass(DeclarativeFieldsMetaclass):
def __new__(mcs, name: str, bases: Sequence[Type[ModelForm]], attrs: Dict[str, Any]) -> Type[ModelForm]: ...
def __new__(mcs, name: str, bases: Sequence[Type[Any]], attrs: Dict[str, Any]) -> Type[ModelForm]: ...
class BaseModelForm(BaseForm):
instance: Any = ...
def __init__(
self,
data: Optional[Dict[str, Any]] = ...,
files: Optional[Dict[str, File]] = ...,
data: Optional[Mapping[str, Any]] = ...,
files: Optional[Mapping[str, File]] = ...,
auto_id: Union[bool, str] = ...,
prefix: Optional[str] = ...,
initial: Optional[Dict[str, Any]] = ...,
@@ -72,21 +70,21 @@ class BaseModelForm(BaseForm):
label_suffix: Optional[str] = ...,
empty_permitted: bool = ...,
instance: Optional[Model] = ...,
use_required_attribute: None = ...,
use_required_attribute: Optional[bool] = ...,
renderer: Any = ...,
) -> None: ...
def clean(self) -> Dict[str, Any]: ...
def validate_unique(self) -> None: ...
save_m2m: Any = ...
def save(self, commit: bool = ...) -> Model: ...
def save(self, commit: bool = ...) -> Any: ...
class ModelForm(BaseModelForm): ...
def modelform_factory(
model: Type[Model],
form: Type[ModelForm] = ...,
fields: Optional[Union[Sequence[str], Literal["__all__"]]] = ...,
exclude: Optional[Sequence[str]] = ...,
fields: Optional[_Fields] = ...,
exclude: Optional[_Fields] = ...,
formfield_callback: Optional[Union[str, Callable[[models.Field], Field]]] = ...,
widgets: Optional[MutableMapping[str, Widget]] = ...,
localized_fields: Optional[Sequence[str]] = ...,
@@ -142,8 +140,8 @@ def modelformset_factory(
can_order: bool = ...,
min_num: Optional[int] = ...,
max_num: Optional[int] = ...,
fields: Optional[Union[str, Sequence[str]]] = ...,
exclude: Optional[Sequence[str]] = ...,
fields: Optional[_Fields] = ...,
exclude: Optional[_Fields] = ...,
widgets: Optional[Dict[str, Any]] = ...,
validate_max: bool = ...,
localized_fields: Optional[Sequence[str]] = ...,
@@ -181,8 +179,8 @@ def inlineformset_factory(
form: Type[ModelForm] = ...,
formset: Type[BaseInlineFormSet] = ...,
fk_name: Optional[str] = ...,
fields: Optional[Union[str, Sequence[str]]] = ...,
exclude: Optional[Sequence[str]] = ...,
fields: Optional[_Fields] = ...,
exclude: Optional[_Fields] = ...,
extra: int = ...,
can_order: bool = ...,
can_delete: bool = ...,
@@ -239,7 +237,7 @@ class ModelChoiceField(ChoiceField):
self,
queryset: Optional[Union[Manager, QuerySet]],
*,
empty_label: str = ...,
empty_label: Optional[str] = ...,
required: bool = ...,
widget: Optional[Any] = ...,
label: Optional[Any] = ...,

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional
from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
TypeInfo
from mypy.plugin import FunctionContext
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
@@ -22,8 +22,13 @@ QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
@@ -125,7 +130,7 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
return Instance(type_to_fill.type, typevar_values)
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
def get_argument_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.
This helper should only be used with non-star arguments.
@@ -140,7 +145,7 @@ def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression
return args[0]
def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
def get_argument_type_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Type]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
@@ -177,8 +182,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression
return None
def iter_over_assignments(
class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]
) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
@@ -281,3 +286,10 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Expression]:
for lvalue, rvalue in iter_over_assignments(type_info.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
return None

View File

@@ -2,10 +2,10 @@ import os
from typing import Callable, Dict, Optional, Union, cast
from mypy.checker import TypeChecker
from mypy.nodes import MemberExpr, TypeInfo
from mypy.nodes import MemberExpr, TypeInfo, NameExpr
from mypy.options import Options
from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType, CallableType, NoneTyp
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
@@ -35,10 +35,10 @@ def transform_manager_class(ctx: ClassDefContext) -> None:
sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1
def transform_modelform_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.MODELFORM_CLASS_FULLNAME)
def transform_form_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['modelform_bases'][ctx.cls.fullname] = 1
sym.node.metadata['django']['baseform_bases'][ctx.cls.fullname] = 1
make_meta_nested_class_inherit_from_any(ctx)
@@ -158,6 +158,47 @@ class ExtractSettingType:
return ctx.default_attr_type
def transform_form_view(ctx: ClassDefContext) -> None:
form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class')
if isinstance(form_class_value, NameExpr):
helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname
def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname:
return ctx.default_return_type
return TypeType(ctx.api.named_generic_type(form_class_fullname, []))
def extract_proper_type_for_get_form(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp):
# extract from specified form_class in metadata
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname:
return ctx.default_return_type
return ctx.api.named_generic_type(form_class_fullname, [])
if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance):
return form_class_type.item
if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance):
return form_class_type.ret_type
return ctx.default_return_type
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -186,8 +227,7 @@ class DjangoPlugin(Plugin):
def _get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (model_sym.node.metadata
.setdefault('django', {})
return (helpers.get_django_metadata(model_sym.node)
.setdefault('model_bases', {helpers.MODEL_CLASS_FULLNAME: 1}))
else:
return {}
@@ -195,18 +235,18 @@ class DjangoPlugin(Plugin):
def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (model_sym.node.metadata
.setdefault('django', {})
return (helpers.get_django_metadata(model_sym.node)
.setdefault('manager_bases', {helpers.MANAGER_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_modelform_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MODELFORM_CLASS_FULLNAME)
def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (model_sym.node.metadata
.setdefault('django', {})
.setdefault('modelform_bases', {helpers.MODELFORM_CLASS_FULLNAME: 1}))
return (helpers.get_django_metadata(model_sym.node)
.setdefault('baseform_bases', {helpers.BASEFORM_CLASS_FULLNAME: 1,
helpers.FORM_CLASS_FULLNAME: 1,
helpers.MODELFORM_CLASS_FULLNAME: 1}))
else:
return {}
@@ -229,6 +269,17 @@ class DjangoPlugin(Plugin):
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
class_name, _, method_name = fullname.rpartition('.')
if method_name == 'get_form_class':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form_class
if method_name == 'get_form':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations
@@ -254,8 +305,12 @@ class DjangoPlugin(Plugin):
if fullname in self._get_current_manager_bases():
return transform_manager_class
if fullname in self._get_current_modelform_bases():
return transform_modelform_class
if fullname in self._get_current_form_bases():
return transform_form_class
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return transform_form_view
return None

View File

@@ -26,9 +26,7 @@ class ModelClassInitializer(metaclass=ABCMeta):
if meta_node is None:
return None
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
return helpers.get_assigned_value_for_class(meta_node, name)
def is_abstract_model(self) -> bool:
is_abstract_expr = self.get_meta_attribute('abstract')

View File

@@ -1,7 +1,7 @@
#!/usr/local/bin/xonsh
try:
pip install wheel
pip install wheel twine
python setup.py sdist bdist_wheel --universal
twine upload dist/*

View File

@@ -4,7 +4,6 @@ ignore_missing_imports = True
check_untyped_defs = True
warn_no_return = False
show_traceback = True
warn_redundant_casts = True
allow_redefinition = True
incremental = False

View File

@@ -228,6 +228,10 @@ IGNORED_ERRORS = {
'"object" has no attribute "items"',
'"Field" has no attribute "many_to_many"'
],
'model_forms': [
'Argument "instance" to "InvalidModelForm" has incompatible type "Type[Category]"; expected "Optional[Model]"',
'Invalid type "NewForm"'
],
'model_fields': [
'Incompatible types in assignment (expression has type "Type[Person]", variable has type',
'Unexpected keyword argument "name" for "Person"',

View File

@@ -16,4 +16,26 @@ class CategoryForm(forms.ModelForm):
fields = '__all__'
class CompositeForm(ArticleForm, CategoryForm):
pass
[out]
[/CASE]
[CASE formview_methods_on_forms_return_proper_types]
from typing import Any
from django import forms
from django.views.generic.edit import FormView
class MyForm(forms.ModelForm):
pass
class MyForm2(forms.ModelForm):
pass
class MyView(FormView):
form_class = MyForm
def post(self, request, *args: Any, **kwds: Any):
form_class = self.get_form_class()
reveal_type(form_class) # E: Revealed type is 'Type[main.MyForm]'
reveal_type(self.get_form(None)) # E: Revealed type is 'main.MyForm'
reveal_type(self.get_form()) # E: Revealed type is 'main.MyForm'
reveal_type(self.get_form(form_class)) # E: Revealed type is 'main.MyForm'
reveal_type(self.get_form(MyForm2)) # E: Revealed type is 'main.MyForm2'
[/CASE]