From 70378b8f40aa2b3f0513eaed1a06a2ebbf882505 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Wed, 13 Feb 2019 17:00:35 +0300 Subject: [PATCH] preserve fallback to Any for unrecognized field types for init/create --- mypy_django_plugin/config.py | 2 +- mypy_django_plugin/main.py | 8 ++--- mypy_django_plugin/plugins/init_create.py | 43 +++++++++++++---------- test-data/typecheck/model_init.test | 16 +++++++++ 4 files changed, 45 insertions(+), 24 deletions(-) diff --git a/mypy_django_plugin/config.py b/mypy_django_plugin/config.py index 34d53f6..6374482 100644 --- a/mypy_django_plugin/config.py +++ b/mypy_django_plugin/config.py @@ -10,7 +10,7 @@ class Config: ignore_missing_settings: bool = False @classmethod - def from_config_file(self, fpath: str) -> 'Config': + def from_config_file(cls, fpath: str) -> 'Config': ini_config = ConfigParser() ini_config.read(fpath) if not ini_config.has_section('mypy_django_plugin'): diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index a7ad2d8..42e8ede 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -9,8 +9,8 @@ from mypy.types import Instance, Type, TypeType from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.config import Config +from mypy_django_plugin.plugins import init_create from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class -from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with @@ -81,7 +81,6 @@ def return_user_model_hook(ctx: FunctionContext) -> Type: return TypeType(Instance(model_info, [])) - class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -150,15 +149,16 @@ class DjangoPlugin(Plugin): if sym and isinstance(sym.node, TypeInfo): if sym.node.has_base(helpers.FIELD_FULLNAME): return record_field_properties_into_outer_model_class + if sym.node.metadata.get('django', {}).get('generated_init'): - return redefine_and_typecheck_model_init + return init_create.redefine_and_typecheck_model_init def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: manager_classes = self._get_current_manager_bases() class_fullname, _, method_name = fullname.rpartition('.') if class_fullname in manager_classes and method_name == 'create': - return redefine_and_typecheck_model_create + return init_create.redefine_and_typecheck_model_create if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: diff --git a/mypy_django_plugin/plugins/init_create.py b/mypy_django_plugin/plugins/init_create.py index cef776a..f04679f 100644 --- a/mypy_django_plugin/plugins/init_create.py +++ b/mypy_django_plugin/plugins/init_create.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Set, cast, Any +from typing import Any, Dict, Optional, Set, cast from mypy.checker import TypeChecker from mypy.nodes import FuncDef, TypeInfo, Var @@ -157,26 +157,31 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) expected_types['pk'] = primary_key_type - for base in model.mro: for name, sym in base.names.items(): - if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): - tp = sym.node.type - field_type = extract_field_setter_type(tp) - if field_type is None: - continue + # do not redefine special attrs + if name in {'_meta', 'pk'}: + continue + if isinstance(sym.node, Var): + if isinstance(sym.node.type, Instance): + tp = sym.node.type + field_type = extract_field_setter_type(tp) + if field_type is None: + continue - choices_type_fullname = extract_choices_type(model, name) - if choices_type_fullname: - field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])]) + choices_type_fullname = extract_choices_type(model, name) + if choices_type_fullname: + field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])]) - if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: - ref_to_model = tp.args[0] - if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): - primary_key_type = extract_primary_key_type(ref_to_model.type) - if not primary_key_type: - primary_key_type = AnyType(TypeOfAny.special_form) - expected_types[name + '_id'] = primary_key_type - if field_type: - expected_types[name] = field_type + if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: + ref_to_model = tp.args[0] + if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): + primary_key_type = extract_primary_key_type(ref_to_model.type) + if not primary_key_type: + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types[name + '_id'] = primary_key_type + if field_type: + expected_types[name] = field_type + elif isinstance(sym.node.type, AnyType): + expected_types[name] = sym.node.type return expected_types diff --git a/test-data/typecheck/model_init.test b/test-data/typecheck/model_init.test index 4f280d8..a15ca46 100644 --- a/test-data/typecheck/model_init.test +++ b/test-data/typecheck/model_init.test @@ -155,3 +155,19 @@ class MyModel(models.Model): day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat'))) MyModel(day=1) [out] + +[CASE if_there_is_no_data_for_base_classes_of_fields_and_ignore_unresolved_attributes_set_to_true_to_not_fail] +from decimal import Decimal +from django.db import models +from fields2 import MoneyField + +class InvoiceRow(models.Model): + base_amount = MoneyField(max_digits=10, decimal_places=2) + vat_rate = models.DecimalField(max_digits=10, decimal_places=2) +InvoiceRow(1, Decimal(0), Decimal(0)) +InvoiceRow(base_amount=Decimal(0), vat_rate=Decimal(0)) +InvoiceRow.objects.create(base_amount=Decimal(0), vat_rate=Decimal(0)) +[out] +main:3: error: Cannot find module named 'fields2' +main:3: note: See https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports +