preserve fallback to Any for unrecognized field types for init/create

This commit is contained in:
Maxim Kurnikov
2019-02-13 17:00:35 +03:00
parent b7f7713c5a
commit 70378b8f40
4 changed files with 45 additions and 24 deletions

View File

@@ -10,7 +10,7 @@ class Config:
ignore_missing_settings: bool = False ignore_missing_settings: bool = False
@classmethod @classmethod
def from_config_file(self, fpath: str) -> 'Config': def from_config_file(cls, fpath: str) -> 'Config':
ini_config = ConfigParser() ini_config = ConfigParser()
ini_config.read(fpath) ini_config.read(fpath)
if not ini_config.has_section('mypy_django_plugin'): if not ini_config.has_section('mypy_django_plugin'):

View File

@@ -9,8 +9,8 @@ from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins import init_create
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
@@ -81,7 +81,6 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, [])) return TypeType(Instance(model_info, []))
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None: def __init__(self, options: Options) -> None:
super().__init__(options) super().__init__(options)
@@ -150,15 +149,16 @@ class DjangoPlugin(Plugin):
if sym and isinstance(sym.node, TypeInfo): if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME): if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'): if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_and_typecheck_model_init return init_create.redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]: ) -> Optional[Callable[[MethodContext], Type]]:
manager_classes = self._get_current_manager_bases() manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.') class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create': if class_fullname in manager_classes and method_name == 'create':
return redefine_and_typecheck_model_create return init_create.redefine_and_typecheck_model_create
if fullname in {'django.apps.registry.Apps.get_model', if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}: 'django.db.migrations.state.StateApps.get_model'}:

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional, Set, cast, Any from typing import Any, Dict, Optional, Set, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import FuncDef, TypeInfo, Var from mypy.nodes import FuncDef, TypeInfo, Var
@@ -157,26 +157,31 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type expected_types['pk'] = primary_key_type
for base in model.mro: for base in model.mro:
for name, sym in base.names.items(): for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): # do not redefine special attrs
tp = sym.node.type if name in {'_meta', 'pk'}:
field_type = extract_field_setter_type(tp) continue
if field_type is None: if isinstance(sym.node, Var):
continue if isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if field_type is None:
continue
choices_type_fullname = extract_choices_type(model, name) choices_type_fullname = extract_choices_type(model, name)
if choices_type_fullname: if choices_type_fullname:
field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])]) field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])])
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0] ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type) primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type: if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form) primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type expected_types[name + '_id'] = primary_key_type
if field_type: if field_type:
expected_types[name] = field_type expected_types[name] = field_type
elif isinstance(sym.node.type, AnyType):
expected_types[name] = sym.node.type
return expected_types return expected_types

View File

@@ -155,3 +155,19 @@ class MyModel(models.Model):
day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat'))) day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat')))
MyModel(day=1) MyModel(day=1)
[out] [out]
[CASE if_there_is_no_data_for_base_classes_of_fields_and_ignore_unresolved_attributes_set_to_true_to_not_fail]
from decimal import Decimal
from django.db import models
from fields2 import MoneyField
class InvoiceRow(models.Model):
base_amount = MoneyField(max_digits=10, decimal_places=2)
vat_rate = models.DecimalField(max_digits=10, decimal_places=2)
InvoiceRow(1, Decimal(0), Decimal(0))
InvoiceRow(base_amount=Decimal(0), vat_rate=Decimal(0))
InvoiceRow.objects.create(base_amount=Decimal(0), vat_rate=Decimal(0))
[out]
main:3: error: Cannot find module named 'fields2'
main:3: note: See https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports