From 6b7507206a0853c8488e048fe5f9b0c6a6ce1bf3 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sun, 10 Feb 2019 04:32:27 +0300 Subject: [PATCH] fix couple edge cases with __init__ --- django-stubs/contrib/sites/models.pyi | 11 +-- django-stubs/db/models/base.pyi | 12 +++- django-stubs/db/models/fields/__init__.pyi | 8 +-- mypy_django_plugin/helpers.py | 78 ++++++++++++++-------- mypy_django_plugin/main.py | 53 ++++++++++----- mypy_django_plugin/plugins/models.py | 12 ++-- scripts/typecheck_tests.py | 12 +++- test-data/typecheck/model_create.test | 57 ++++++++++++++-- 8 files changed, 172 insertions(+), 71 deletions(-) diff --git a/django-stubs/contrib/sites/models.pyi b/django-stubs/contrib/sites/models.pyi index d8d224f..7b74d89 100644 --- a/django-stubs/contrib/sites/models.pyi +++ b/django-stubs/contrib/sites/models.pyi @@ -7,19 +7,14 @@ from django.db import models SITE_CACHE: Any class SiteManager(models.Manager): - creation_counter: int - model: None - name: None - use_in_migrations: bool = ... def get_current(self, request: Optional[HttpRequest] = ...) -> Site: ... def clear_cache(self) -> None: ... def get_by_natural_key(self, domain: str) -> Site: ... class Site(models.Model): - id: int - domain: str = ... - name: str = ... - objects: Any = ... + domain: models.CharField = ... + name: models.CharField = ... + objects: SiteManager = ... def natural_key(self) -> Tuple[str]: ... def clear_site_cache(sender: Type[Site], **kwargs: Any) -> None: ... diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 2bd8330..4e1ca23 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar +from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence from django.db.models.manager import Manager @@ -22,8 +22,16 @@ class Model(metaclass=ModelBase): force_insert: bool = ..., force_update: bool = ..., using: Optional[str] = ..., - update_fields: Optional[Union[List[str], str]] = ..., + update_fields: Optional[Union[Sequence[str], str]] = ..., ) -> None: ... + def save_base( + self, + raw: bool = ..., + force_insert: bool = ..., + force_update: bool = ..., + using: Optional[str] = ..., + update_fields: Optional[Union[Sequence[str], str]] = ..., + ): ... def refresh_from_db(self: _Self, using: Optional[str] = ..., fields: Optional[List[str]] = ...) -> _Self: ... def get_deferred_fields(self) -> Set[str]: ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index fc11c80..c871813 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -62,7 +62,7 @@ class Field(RegisterLookupMixin): def to_python(self, value: Any) -> Any: ... class IntegerField(Field): - def __set__(self, instance, value: Union[int, F]) -> None: ... + def __set__(self, instance, value: Union[int, Combinable]) -> None: ... def __get__(self, instance, owner) -> int: ... class PositiveIntegerRelDbTypeMixin: @@ -231,7 +231,7 @@ class DateField(DateTimeCheckMixin, Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Any) -> None: ... + def __set__(self, instance, value: Union[str, date, Combinable]) -> None: ... def __get__(self, instance, owner) -> date: ... class TimeField(DateTimeCheckMixin, Field): @@ -257,11 +257,11 @@ class TimeField(DateTimeCheckMixin, Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Any) -> None: ... + def __set__(self, instance, value: Union[str, time, datetime, Combinable]) -> None: ... def __get__(self, instance, owner) -> time: ... class DateTimeField(DateField): - def __set__(self, instance, value: Any) -> None: ... + def __set__(self, instance, value: Union[str, date, datetime, Combinable]) -> None: ... def __get__(self, instance, owner) -> datetime: ... class UUIDField(Field): diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 9dfb43c..65a0c14 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,9 +1,10 @@ import typing from typing import Dict, Optional -from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var +from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \ + CallExpr from mypy.plugin import FunctionContext -from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType +from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' FIELD_FULLNAME = 'django.db.models.fields.Field' @@ -101,10 +102,23 @@ def fill_typevars_with_any(instance: Instance) -> Type: return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)]) -def extract_typevar_value(tp: Instance, typevar_name: str): +def extract_typevar_value(tp: Instance, typevar_name: str) -> Type: + if typevar_name in {'_T', '_T_co'}: + if '_T' in tp.type.type_vars: + return tp.args[tp.type.type_vars.index('_T')] + if '_T_co' in tp.type.type_vars: + return tp.args[tp.type.type_vars.index('_T_co')] return tp.args[tp.type.type_vars.index(typevar_name)] +def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance: + typevar_values: typing.List[Type] = [] + for typevar_arg in type_to_fill.args: + if isinstance(typevar_arg, TypeVarType): + typevar_values.append(extract_typevar_value(tp, typevar_arg.name)) + return reparametrize_with(type_to_fill, typevar_values) + + def extract_field_setter_type(tp: Instance) -> Optional[Type]: if tp.type.has_base(FIELD_FULLNAME): set_method = tp.type.get_method('__set__') @@ -112,13 +126,15 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]: if 'value' in set_method.type.arg_names: set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] if isinstance(set_value_type, Instance): - typevar_values: typing.List[Type] = [] - for typevar_arg in set_value_type.args: - if isinstance(typevar_arg, TypeVarType): - typevar_values.append(extract_typevar_value(tp, typevar_arg.name)) - # if there are typevars, extract from - set_value_type = reparametrize_with(set_value_type, typevar_values) + set_value_type = fill_typevars(tp, set_value_type) return set_value_type + elif isinstance(set_value_type, UnionType): + items_no_typevars = [] + for item in set_value_type.items: + if isinstance(item, Instance): + item = fill_typevars(tp, item) + items_no_typevars.append(item) + return UnionType(items_no_typevars) get_method = tp.type.get_method('__get__') if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): @@ -131,31 +147,20 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]: def extract_primary_key_type(model: TypeInfo) -> Optional[Type]: # only primary keys defined in current class for now - for sym in model.names.values(): - if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): - tp = sym.node.type - if tp.type.metadata.get('django', {}).get('defined_as_primary_key'): - field_type = extract_field_setter_type(tp) - return field_type + for stmt in model.defn.defs.body: + if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr): + name_expr = stmt.lvalues[0] + if isinstance(name_expr, NameExpr): + name = name_expr.name + if 'primary_key' in stmt.rvalue.arg_names: + is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')] + if is_primary_key: + return extract_field_setter_type(model.names[name].type) return None def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: expected_types: Dict[str, Type] = {} - for base in model.mro: - for name, sym in base.names.items(): - if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): - tp = sym.node.type - field_type = extract_field_setter_type(tp) - if tp.type.fullname() == FOREIGN_KEY_FULLNAME: - ref_to_model = tp.args[0] - if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME): - primary_key_type = extract_primary_key_type(ref_to_model.type) - if not primary_key_type: - primary_key_type = AnyType(TypeOfAny.special_form) - expected_types[name + '_id'] = primary_key_type - if field_type: - expected_types[name] = field_type primary_key_type = extract_primary_key_type(model) if not primary_key_type: @@ -164,6 +169,21 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) expected_types['pk'] = primary_key_type + + for base in model.mro: + for name, sym in base.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + field_type = extract_field_setter_type(tp) + if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}: + ref_to_model = tp.args[0] + if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME): + primary_key_type = extract_primary_key_type(ref_to_model.type) + if not primary_key_type: + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types[name + '_id'] = primary_key_type + if field_type: + expected_types[name] = field_type return expected_types diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 8ff0216..d5224d4 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,6 +1,6 @@ import os from configparser import ConfigParser -from typing import Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, Set, cast from dataclasses import dataclass from mypy.checker import TypeChecker @@ -10,7 +10,6 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.types import Instance, Type from mypy_django_plugin import helpers, monkeypatch -from mypy_django_plugin.helpers import parse_bool from mypy_django_plugin.plugins.fields import determine_type_of_array_field from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations from mypy_django_plugin.plugins.models import process_model_class @@ -57,6 +56,16 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: return ret +def extract_base_pointer_args(model: TypeInfo) -> Set[str]: + pointer_args: Set[str] = set() + for base in model.bases: + if base.type.has_base(helpers.MODEL_CLASS_FULLNAME): + parent_name = base.type.name().lower() + pointer_args.add(f'{parent_name}_ptr') + pointer_args.add(f'{parent_name}_ptr_id') + return pointer_args + + def redefine_model_init(ctx: FunctionContext) -> Type: assert isinstance(ctx.default_return_type, Instance) @@ -64,9 +73,33 @@ def redefine_model_init(ctx: FunctionContext) -> Type: model: TypeInfo = ctx.default_return_type.type expected_types = helpers.extract_expected_types(ctx, model) - for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + # order is preserved, can use for positionals + positional_names = list(expected_types.keys()) + positional_names.remove('pk') + visited_positionals = set() + + # check positionals + for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])): + actual_pos_name = positional_names[i] + api.check_subtype(actual_pos_type, expected_types[actual_pos_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_pos_name, + model.name()), + 'got', 'expected') + visited_positionals.add(actual_pos_name) + + # extract name of base models for _ptr + base_pointer_args = extract_base_pointer_args(model) + + # check kwargs + for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])): + if actual_name in base_pointer_args: + # parent_ptr args are not supported + continue + if actual_name in visited_positionals: + continue if actual_name is None: - # We can't check kwargs reliably. + # unpacked dict as kwargs is not supported continue if actual_name not in expected_types: ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, @@ -81,16 +114,6 @@ def redefine_model_init(ctx: FunctionContext) -> Type: return ctx.default_return_type -def set_primary_key_marking(ctx: FunctionContext) -> Type: - primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key') - if primary_key_arg: - is_primary_key = parse_bool(primary_key_arg) - if is_primary_key: - info = ctx.default_return_type.type - info.metadata.setdefault('django', {})['defined_as_primary_key'] = True - return ctx.default_return_type - - @dataclass class Config: django_settings_module: Optional[str] = None @@ -171,8 +194,6 @@ class DjangoPlugin(Plugin): sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): - if sym.node.has_base(helpers.FIELD_FULLNAME): - return set_primary_key_marking if sym.node.metadata.get('django', {}).get('generated_init'): return redefine_model_init diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 823ddc7..7f65960 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -3,7 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple, cast import dataclasses from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \ - MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2 + MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2, ARG_STAR from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzerPass2 @@ -202,9 +202,13 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr, def add_dummy_init_method(ctx: ClassDefContext) -> None: any = AnyType(TypeOfAny.special_form) - var = Var('kwargs', any) - kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2) - add_method(ctx, '__init__', [kw_arg], NoneTyp()) + + pos_arg = Argument(variable=Var('args', any), + type_annotation=any, initializer=None, kind=ARG_STAR) + kw_arg = Argument(variable=Var('kwargs', any), + type_annotation=any, initializer=None, kind=ARG_STAR2) + + add_method(ctx, '__init__', [pos_arg, kw_arg], NoneTyp()) # mark as model class ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index db0cbac..dd52320 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -101,7 +101,8 @@ IGNORED_ERRORS = { ], 'basic': [ 'Unexpected keyword argument "unknown_kwarg" for "refresh_from_db" of "Model"', - '"refresh_from_db" of "Model" defined here' + '"refresh_from_db" of "Model" defined here', + 'Unexpected attribute "foo" for model "Article"' ], 'builtin_server': [ 'has no attribute "getvalue"' @@ -174,6 +175,10 @@ IGNORED_ERRORS = { 'Unexpected keyword argument "name" for "Person"', 'Cannot assign multiple types to name "PersonTwoImages" without an explicit "Type[...]" annotation', ], + 'model_regress': [ + 'Too many arguments for "Worker"', + re.compile(r'Incompatible type for "[a-z]+" of "Worker" \(got "int", expected') + ], 'modeladmin': [ 'BandAdmin', ], @@ -200,6 +205,9 @@ IGNORED_ERRORS = { 'DummyJSONField', 'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"' ], + 'properties': [ + re.compile('Unexpected attribute "(full_name|full_name_2)" for model "Person"') + ], 'requests': [ 'Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "QueryDict")' ], @@ -420,7 +428,7 @@ TESTS_DIRS = [ 'model_package', 'model_regress', # not practical - # 'modeladmin', + 'modeladmin', # TODO: 'multiple_database', 'mutually_referential', 'nested_foreign_keys', diff --git a/test-data/typecheck/model_create.test b/test-data/typecheck/model_create.test index ecf3b50..bd6986a 100644 --- a/test-data/typecheck/model_create.test +++ b/test-data/typecheck/model_create.test @@ -16,7 +16,7 @@ class MyUser(models.Model): age = models.IntegerField() user = MyUser(name=1, age=12) [out] -main:6: error: Incompatible type for "name" of "MyUser" (got "int", expected "str") +main:6: error: Incompatible type for "name" of "MyUser" (got "int", expected "Union[str, Combinable]") [CASE arguments_to_init_combined_from_base_classes] from django.db import models @@ -53,7 +53,7 @@ from django.db import models class MyUser1(models.Model): mypk = models.CharField(primary_key=True) class MyUser2(models.Model): - pass + name = models.CharField(max_length=100) user2 = MyUser1(pk='hello') user3= MyUser2(pk=1) [out] @@ -63,7 +63,7 @@ from django.db import models class MyUser1(models.Model): mypk = models.CharField(primary_key=True) -user = MyUser1(pk=1) # E: Incompatible type for "pk" of "MyUser1" (got "int", expected "str") +user = MyUser1(pk=1) # E: Incompatible type for "pk" of "MyUser1" (got "int", expected "Union[str, Combinable]") [out] [CASE can_set_foreign_key_by_its_primary_key] @@ -78,7 +78,7 @@ class Book(models.Model): publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE) Book(publisher_id=1, publisher_with_char_pk_id='hello') -Book(publisher_id=1, publisher_with_char_pk_id=1) # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "int", expected "str") +Book(publisher_id=1, publisher_with_char_pk_id=1) # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "int", expected "Union[str, Combinable]") [out] [CASE setting_value_to_an_array_of_ints] @@ -100,7 +100,52 @@ MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got [CASE if_no_explicit_primary_key_id_can_be_passed] from django.db import models +class MyModel(models.Model): + name = models.CharField(max_length=100) +MyModel(id=1, name='maxim') +[out] + +[CASE arguments_can_be_passed_as_positionals] +from django.db import models class MyModel(models.Model): pass -MyModel(id=1) -[out] \ No newline at end of file +MyModel(1) + +class MyModel2(models.Model): + name = models.CharField(max_length=100) +MyModel2(1, 'Maxim') +MyModel2(1, 12) # E: Incompatible type for "name" of "MyModel2" (got "int", expected "Union[str, Combinable]") +[out] + +[CASE arguments_passed_as_dictionary_unpacking_are_not_supported] +from django.db import models +class MyModel(models.Model): + name = models.CharField(max_length=100) +MyModel(**{'name': 'hello'}) +[out] + +[CASE pointer_to_parent_model_is_not_supported] +from django.db import models +class Place(models.Model): + pass +class Restaurant(Place): + pass +place = Place() +Restaurant(place_ptr=place) +Restaurant(place_ptr_id=place.id) +[out] + +[CASE extract_type_of_init_param_from_set_method] +from typing import Union +from datetime import time + +from django.db import models +class MyField(models.Field): + def __set__(self, instance, value: Union[str, time]) -> None: pass + def __get__(self, instance, owner) -> time: pass +class MyModel(models.Model): + field = MyField() +MyModel(field=time()) +MyModel(field='12:00') +MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]") +