From 916df1efb61e9a7b9bd692c80bf076abffcae674 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Fri, 8 Feb 2019 17:16:03 +0300 Subject: [PATCH] add Model.__init__ typechecking --- django-stubs/contrib/admin/models.pyi | 34 ++--- django-stubs/contrib/auth/models.pyi | 46 +++---- .../contrib/postgres/fields/__init__.pyi | 1 + .../contrib/postgres/fields/array.pyi | 8 +- .../contrib/postgres/fields/hstore.pyi | 17 +++ django-stubs/db/models/fields/__init__.pyi | 2 +- django-stubs/db/models/query.pyi | 15 +-- django-stubs/forms/widgets.pyi | 30 ++--- mypy_django_plugin/helpers.py | 119 +++++++++++++++++- mypy_django_plugin/main.py | 42 +++++++ mypy_django_plugin/plugins/fields.py | 9 +- mypy_django_plugin/plugins/models.py | 20 ++- mypy_django_plugin/plugins/related_fields.py | 12 +- scripts/typecheck_tests.py | 4 +- test-data/typecheck/fields.test | 2 +- test-data/typecheck/model_create.test | 106 ++++++++++++++++ 16 files changed, 359 insertions(+), 108 deletions(-) create mode 100644 django-stubs/contrib/postgres/fields/hstore.pyi create mode 100644 test-data/typecheck/model_create.test diff --git a/django-stubs/contrib/admin/models.pyi b/django-stubs/contrib/admin/models.pyi index 9796299..f08d362 100644 --- a/django-stubs/contrib/admin/models.pyi +++ b/django-stubs/contrib/admin/models.pyi @@ -1,20 +1,17 @@ -import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID -from django.db import models +from django.contrib.contenttypes.models import ContentType from django.db.models.base import Model +from django.db import models + ADDITION: int CHANGE: int DELETION: int ACTION_FLAG_CHOICES: Any class LogEntryManager(models.Manager): - creation_counter: int - model: None - name: None - use_in_migrations: bool = ... def log_action( self, user_id: int, @@ -22,23 +19,18 @@ class LogEntryManager(models.Manager): object_id: Union[int, str, UUID], object_repr: str, action_flag: int, - change_message: Union[ - Dict[str, Dict[str, List[str]]], List[Dict[str, Dict[str, Union[List[str], str]]]], str - ] = ..., + change_message: Any = ..., ) -> LogEntry: ... class LogEntry(models.Model): - content_type_id: int - id: None - user_id: int - action_time: datetime.datetime = ... - user: Any = ... - content_type: Any = ... - object_id: str = ... - object_repr: str = ... - action_flag: int = ... - change_message: str = ... - objects: Any = ... + action_time: models.DateTimeField = ... + user: models.ForeignKey = ... + content_type: models.ForeignKey[ContentType] = ... + object_id: models.TextField = ... + object_repr: models.CharField = ... + action_flag: models.PositiveSmallIntegerField = ... + change_message: models.TextField = ... + objects: LogEntryManager = ... def is_addition(self) -> bool: ... def is_change(self) -> bool: ... def is_deletion(self) -> bool: ... diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index e7e3995..e3ea354 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -1,46 +1,33 @@ -import datetime from typing import Any, List, Optional, Set, Tuple, Type, Union from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager +from django.contrib.contenttypes.models import ContentType from django.db.models.manager import EmptyManager +from django.contrib.auth.validators import UnicodeUsernameValidator from django.db import models def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ... class PermissionManager(models.Manager): - creation_counter: int - model: None - name: None - use_in_migrations: bool = ... def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ... class Permission(models.Model): content_type_id: int - id: int - name: str = ... - content_type: Any = ... + name: models.CharField = ... + content_type: models.ForeignKey[ContentType] = ... codename: str = ... def natural_key(self) -> Tuple[str, str, str]: ... class GroupManager(models.Manager): - creation_counter: int - model: None - name: None - use_in_migrations: bool = ... def get_by_natural_key(self, name: str) -> Group: ... class Group(models.Model): - id: None - name: str = ... - permissions: Any = ... + name: models.CharField = ... + permissions: models.ManyToManyField[Permission] = ... def natural_key(self): ... class UserManager(BaseUserManager): - creation_counter: int - model: None - name: None - use_in_migrations: bool = ... def create_user( self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any ) -> AbstractUser: ... @@ -49,9 +36,9 @@ class UserManager(BaseUserManager): ) -> AbstractBaseUser: ... class PermissionsMixin(models.Model): - is_superuser: Any = ... - groups: Any = ... - user_permissions: Any = ... + is_superuser: models.BooleanField = ... + groups: models.ManyToManyField[Group] = ... + user_permissions: models.ManyToManyField[Permission] = ... def get_group_permissions(self, obj: None = ...) -> Set[str]: ... def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ... def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ... @@ -59,14 +46,13 @@ class PermissionsMixin(models.Model): def has_module_perms(self, app_label: str) -> bool: ... class AbstractUser(AbstractBaseUser, PermissionsMixin): # type: ignore - is_superuser: bool - username_validator: Any = ... - username: str = ... - first_name: str = ... - last_name: str = ... - email: str = ... - is_staff: bool = ... - date_joined: datetime.datetime = ... + username_validator: UnicodeUsernameValidator = ... + username: models.CharField = ... + first_name: models.CharField = ... + last_name: models.CharField = ... + email: models.EmailField = ... + is_staff: models.BooleanField = ... + date_joined: models.DateTimeField = ... EMAIL_FIELD: str = ... USERNAME_FIELD: str = ... def clean(self) -> None: ... diff --git a/django-stubs/contrib/postgres/fields/__init__.pyi b/django-stubs/contrib/postgres/fields/__init__.pyi index 58d3fbc..3c93c18 100644 --- a/django-stubs/contrib/postgres/fields/__init__.pyi +++ b/django-stubs/contrib/postgres/fields/__init__.pyi @@ -8,3 +8,4 @@ from .ranges import ( DateRangeField as DateRangeField, DateTimeRangeField as DateTimeRangeField, ) +from .hstore import HStoreField as HStoreField diff --git a/django-stubs/contrib/postgres/fields/array.pyi b/django-stubs/contrib/postgres/fields/array.pyi index c7f7510..6fa9488 100644 --- a/django-stubs/contrib/postgres/fields/array.pyi +++ b/django-stubs/contrib/postgres/fields/array.pyi @@ -13,14 +13,8 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]): default_validators: Any = ... from_db_value: Any = ... def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ... - def check(self, **kwargs: Any) -> List[Any]: ... @property def description(self): ... - def get_db_prep_value(self, value: Any, connection: Any, prepared: bool = ...): ... - def to_python(self, value: Any): ... - def value_to_string(self, obj: Any): ... def get_transform(self, name: Any): ... - def validate(self, value: Any, model_instance: Any) -> None: ... - def run_validators(self, value: Any) -> None: ... - def __set__(self, instance, value: Sequence[_T]): ... + def __set__(self, instance, value: Sequence[_T]) -> None: ... def __get__(self, instance, owner) -> List[_T]: ... diff --git a/django-stubs/contrib/postgres/fields/hstore.pyi b/django-stubs/contrib/postgres/fields/hstore.pyi new file mode 100644 index 0000000..e94ea6a --- /dev/null +++ b/django-stubs/contrib/postgres/fields/hstore.pyi @@ -0,0 +1,17 @@ +from typing import Any + +from django.db.models import Field, Transform +from .mixins import CheckFieldDefaultMixin + +class HStoreField(CheckFieldDefaultMixin, Field): + def get_transform(self, name) -> Any: ... + +class KeyTransform(Transform): + def __init__(self, key_name: str, *args: Any, **kwargs: Any): ... + +class KeyTransformFactory: + def __init__(self, key_name: str): ... + def __call__(self, *args, **kwargs) -> KeyTransform: ... + +class KeysTransform(Transform): ... +class ValuesTransform(Transform): ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 71a71fd..fc11c80 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -1,6 +1,6 @@ import uuid from datetime import date, time, datetime, timedelta -from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type +from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar import decimal from django.db.models import Model diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 115b2d4..7f0f44c 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -95,13 +95,14 @@ class QuerySet(Iterable[_T], Sized): def raw( self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ... ) -> RawQuerySet: ... - def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesIterable: ... - @overload - def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ... - @overload - def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ... - @overload - def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ... + def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ... + def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ... + # @overload + # def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ... + # @overload + # def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ... + # @overload + # def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ... def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ... def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ... def none(self) -> QuerySet[_T]: ... diff --git a/django-stubs/forms/widgets.pyi b/django-stubs/forms/widgets.pyi index e92c22b..b4a7482 100644 --- a/django-stubs/forms/widgets.pyi +++ b/django-stubs/forms/widgets.pyi @@ -1,7 +1,7 @@ from datetime import time from decimal import Decimal from itertools import chain -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable, Sequence from django.contrib.admin.options import BaseModelAdmin from django.core.files.base import File @@ -114,6 +114,8 @@ class CheckboxInput(Input): check_test: Callable = ... def __init__(self, attrs: Optional[Dict[str, str]] = ..., check_test: Optional[Callable] = ...) -> None: ... +_OptAttrs = Dict[str, Any] + class ChoiceWidget(Widget): allow_multiple_selected: bool = ... input_type: Optional[str] = ... @@ -123,17 +125,9 @@ class ChoiceWidget(Widget): checked_attribute: Any = ... option_inherits_attrs: bool = ... choices: List[List[Union[int, str]]] = ... - def __init__( - self, - attrs: Optional[Dict[str, Union[bool, str]]] = ..., - choices: Union[ - Iterator[Any], List[List[Union[int, str]]], List[Tuple[Union[time, int], int]], List[int], Tuple - ] = ..., - ) -> None: ... - def options(self, name: str, value: List[str], attrs: Dict[str, Union[bool, str]] = ...) -> None: ... - def optgroups( - self, name: str, value: List[str], attrs: Optional[Dict[str, Union[bool, str]]] = ... - ) -> List[Tuple[Optional[str], List[Dict[str, Union[Dict[str, Union[bool, str]], time, int, str]]], int]]: ... + def __init__(self, attrs: Optional[_OptAttrs] = ..., choices: Sequence[Tuple[Any, Any]] = ...) -> None: ... + def options(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> None: ... + def optgroups(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> Any: ... def create_option( self, name: str, @@ -142,8 +136,8 @@ class ChoiceWidget(Widget): selected: Union[Set[str], bool], index: int, subindex: Optional[int] = ..., - attrs: Optional[Dict[str, Union[bool, str]]] = ..., - ) -> Dict[str, Union[Dict[str, Union[bool, str]], Dict[str, bool], Set[str], time, int, str]]: ... + attrs: Optional[_OptAttrs] = ..., + ) -> Dict[str, Any]: ... def id_for_label(self, id_: str, index: str = ...) -> str: ... class Select(ChoiceWidget): @@ -171,11 +165,7 @@ class CheckboxSelectMultiple(ChoiceWidget): class MultiWidget(Widget): template_name: str = ... widgets: List[Widget] = ... - def __init__( - self, - widgets: Union[List[Type[DateTimeBaseInput]], Tuple[Union[Type[TextInput], Input]]], - attrs: Optional[Dict[str, str]] = ..., - ) -> None: ... + def __init__(self, widgets: Sequence[Widget], attrs: Optional[_OptAttrs] = ...) -> None: ... @property def is_hidden(self) -> bool: ... def decompress(self, value: Any) -> Optional[Any]: ... @@ -218,7 +208,7 @@ class SelectDateWidget(Widget): day_none_value: Any = ... def __init__( self, - attrs: None = ..., + attrs: Optional[_OptAttrs] = ..., years: Optional[Union[Tuple[Union[int, str]], range]] = ..., months: None = ..., empty_label: Optional[Union[Tuple[str, str], str]] = ..., diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index f206148..9dfb43c 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,9 +1,13 @@ import typing from typing import Dict, Optional -from mypy.nodes import MypyFile, TypeInfo, ImportedName, SymbolNode +from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var +from mypy.plugin import FunctionContext +from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' +FIELD_FULLNAME = 'django.db.models.fields.Field' +GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey' FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField' @@ -78,3 +82,116 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) if sym is None: return None return sym.node + + +def parse_bool(expr: Expression) -> Optional[bool]: + if isinstance(expr, NameExpr): + if expr.fullname == 'builtins.True': + return True + if expr.fullname == 'builtins.False': + return False + return None + + +def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): + return Instance(instance.type, args=new_typevars) + + +def fill_typevars_with_any(instance: Instance) -> Type: + return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)]) + + +def extract_typevar_value(tp: Instance, typevar_name: str): + return tp.args[tp.type.type_vars.index(typevar_name)] + + +def extract_field_setter_type(tp: Instance) -> Optional[Type]: + if tp.type.has_base(FIELD_FULLNAME): + set_method = tp.type.get_method('__set__') + if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): + if 'value' in set_method.type.arg_names: + set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] + if isinstance(set_value_type, Instance): + typevar_values: typing.List[Type] = [] + for typevar_arg in set_value_type.args: + if isinstance(typevar_arg, TypeVarType): + typevar_values.append(extract_typevar_value(tp, typevar_arg.name)) + # if there are typevars, extract from + set_value_type = reparametrize_with(set_value_type, typevar_values) + return set_value_type + + get_method = tp.type.get_method('__get__') + if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): + return get_method.type.ret_type + # GenericForeignKey + if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): + return AnyType(TypeOfAny.special_form) + return None + + +def extract_primary_key_type(model: TypeInfo) -> Optional[Type]: + # only primary keys defined in current class for now + for sym in model.names.values(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + if tp.type.metadata.get('django', {}).get('defined_as_primary_key'): + field_type = extract_field_setter_type(tp) + return field_type + return None + + +def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: + expected_types: Dict[str, Type] = {} + for base in model.mro: + for name, sym in base.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + field_type = extract_field_setter_type(tp) + if tp.type.fullname() == FOREIGN_KEY_FULLNAME: + ref_to_model = tp.args[0] + if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME): + primary_key_type = extract_primary_key_type(ref_to_model.type) + if not primary_key_type: + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types[name + '_id'] = primary_key_type + if field_type: + expected_types[name] = field_type + + primary_key_type = extract_primary_key_type(model) + if not primary_key_type: + # no explicit primary key, set pk to Any and add id + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) + + expected_types['pk'] = primary_key_type + return expected_types + + +def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: + """Return the expression for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + +def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]: + """Return the type for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + arg_types = ctx.arg_types[idx] + if len(arg_types) != 1: + # Either an error or no value passed. + return None + return arg_types[0] diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 13299ce..38a49ef 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -8,6 +8,7 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.types import Instance, Type from mypy_django_plugin import helpers, monkeypatch +from mypy_django_plugin.helpers import parse_bool from mypy_django_plugin.plugins.fields import determine_type_of_array_field from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations from mypy_django_plugin.plugins.models import process_model_class @@ -54,6 +55,40 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: return ret +def redefine_model_init(ctx: FunctionContext) -> Type: + assert isinstance(ctx.default_return_type, Instance) + + api = cast(TypeChecker, ctx.api) + model: TypeInfo = ctx.default_return_type.type + + expected_types = helpers.extract_expected_types(ctx, model) + for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name is None: + # We can't check kwargs reliably. + continue + if actual_name not in expected_types: + ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, + model.name()), + ctx.context) + continue + api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, + model.name()), + 'got', 'expected') + return ctx.default_return_type + + +def set_primary_key_marking(ctx: FunctionContext) -> Type: + primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key') + if primary_key_arg: + is_primary_key = parse_bool(primary_key_arg) + if is_primary_key: + info = ctx.default_return_type.type + info.metadata.setdefault('django', {})['defined_as_primary_key'] = True + return ctx.default_return_type + + class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -105,6 +140,13 @@ class DjangoPlugin(Plugin): if fullname in manager_bases: return determine_proper_manager_type + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + if sym.node.has_base(helpers.FIELD_FULLNAME): + return set_primary_key_marking + if sym.node.metadata.get('django', {}).get('generated_init'): + return redefine_model_init + def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: if fullname in {'django.apps.registry.Apps.get_model', diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py index 539b0d3..f99baed 100644 --- a/mypy_django_plugin/plugins/fields.py +++ b/mypy_django_plugin/plugins/fields.py @@ -1,13 +1,12 @@ from mypy.plugin import FunctionContext from mypy.types import Type, Instance +from mypy_django_plugin import helpers + def determine_type_of_array_field(ctx: FunctionContext) -> Type: - if 'base_field' not in ctx.callee_arg_names: - return ctx.default_return_type - - base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0] - if not isinstance(base_field_arg_type, Instance): + base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field') + if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): return ctx.default_return_type get_method = base_field_arg_type.type.get_method('__get__') diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 91e1d80..823ddc7 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -2,11 +2,12 @@ from abc import ABCMeta, abstractmethod from typing import Dict, Iterator, Optional, Tuple, cast import dataclasses -from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, \ - StrExpr, SymbolTableNode, TypeInfo, Var +from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \ + MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2 from mypy.plugin import ClassDefContext +from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Instance +from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp from mypy_django_plugin import helpers @@ -199,16 +200,27 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr, return None +def add_dummy_init_method(ctx: ClassDefContext) -> None: + any = AnyType(TypeOfAny.special_form) + var = Var('kwargs', any) + kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2) + add_method(ctx, '__init__', [kw_arg], NoneTyp()) + # mark as model class + ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True + + def process_model_class(ctx: ClassDefContext) -> None: initializers = [ InjectAnyAsBaseForNestedMeta, AddDefaultObjectsManager, AddIdAttributeIfPrimaryKeyTrueIsNotSet, SetIdAttrsForRelatedFields, - AddRelatedManagers + AddRelatedManagers, ] for initializer_cls in initializers: initializer_cls.from_ctx(ctx).run() + add_dummy_init_method(ctx) + # allow unspecified attributes for now ctx.cls.info.fallback_to_any = True diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index df3e4c9..2fd0aac 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -1,20 +1,12 @@ -import typing from typing import Optional, cast from mypy.checker import TypeChecker from mypy.nodes import StrExpr, TypeInfo from mypy.plugin import FunctionContext -from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny +from mypy.types import CallableType, Instance, Type from mypy_django_plugin import helpers - - -def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): - return Instance(instance.type, args=new_typevars) - - -def fill_typevars_with_any(instance: Instance) -> Type: - return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)]) +from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index 0343a79..db0cbac 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -29,6 +29,7 @@ IGNORED_ERRORS = { 'Invalid value for a to= parameter', 'already defined (possibly by an import)', 'Cannot assign to a type', + re.compile(r'Cannot assign to class variable "[a-z_]+" via instance'), # forms <-> models plugin support '"Model" has no attribute', re.compile(r'Cannot determine type of \'(objects|stuff)\''), @@ -73,7 +74,8 @@ IGNORED_ERRORS = { 'admin_views': [ 'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"', 'Incompatible types in assignment', - '"object" not callable' + '"object" not callable', + 'Incompatible type for "pk" of "Collector" (got "int", expected "str")' ], 'aggregation': [ 'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")', diff --git a/test-data/typecheck/fields.test b/test-data/typecheck/fields.test index 047677b..c3c2d21 100644 --- a/test-data/typecheck/fields.test +++ b/test-data/typecheck/fields.test @@ -93,4 +93,4 @@ class Abstract(models.Model): abstract = True class User(Abstract): id = models.AutoField(primary_key=True) -[out] +[out] \ No newline at end of file diff --git a/test-data/typecheck/model_create.test b/test-data/typecheck/model_create.test new file mode 100644 index 0000000..ecf3b50 --- /dev/null +++ b/test-data/typecheck/model_create.test @@ -0,0 +1,106 @@ +[CASE arguments_to_init_unexpected_attributes] +from django.db import models + +class MyUser(models.Model): + pass +user = MyUser(name=1, age=12) +[out] +main:5: error: Unexpected attribute "name" for model "MyUser" +main:5: error: Unexpected attribute "age" for model "MyUser" + +[CASE arguments_to_init_from_class_incompatible_type] +from django.db import models + +class MyUser(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() +user = MyUser(name=1, age=12) +[out] +main:6: error: Incompatible type for "name" of "MyUser" (got "int", expected "str") + +[CASE arguments_to_init_combined_from_base_classes] +from django.db import models + +class BaseUser(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() +class ChildUser(BaseUser): + lastname = models.CharField(max_length=100) +user = ChildUser(name='Max', age=12, lastname='Lastname') +[out] + +[CASE fields_from_abstract_user_propagate_to_init] +from django.contrib.auth.models import AbstractUser + +class MyUser(AbstractUser): + pass +user = MyUser(username='maxim', password='password', first_name='Max', last_name='MaxMax') +[out] + +[CASE generic_foreign_key_field_no_typechecking] +from django.db import models +from django.contrib.contenttypes.fields import GenericForeignKey + +class MyUser(models.Model): + content_object = GenericForeignKey() + +user = MyUser(content_object=1) +[out] + +[CASE pk_refers_to_primary_key_and_could_be_passed_to_init] +from django.db import models + +class MyUser1(models.Model): + mypk = models.CharField(primary_key=True) +class MyUser2(models.Model): + pass +user2 = MyUser1(pk='hello') +user3= MyUser2(pk=1) +[out] + +[CASE typechecking_of_pk] +from django.db import models + +class MyUser1(models.Model): + mypk = models.CharField(primary_key=True) +user = MyUser1(pk=1) # E: Incompatible type for "pk" of "MyUser1" (got "int", expected "str") +[out] + +[CASE can_set_foreign_key_by_its_primary_key] +from django.db import models + +class Publisher(models.Model): + pass +class PublisherWithCharPK(models.Model): + id = models.CharField(max_length=100, primary_key=True) +class Book(models.Model): + publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) + publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE) + +Book(publisher_id=1, publisher_with_char_pk_id='hello') +Book(publisher_id=1, publisher_with_char_pk_id=1) # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "int", expected "str") +[out] + +[CASE setting_value_to_an_array_of_ints] +from typing import List, Tuple + +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class MyModel(models.Model): + array = ArrayField(base_field=models.IntegerField()) +array_val: Tuple[int, ...] = (1,) +MyModel(array=array_val) +array_val2: List[int] = [1] +MyModel(array=array_val2) +array_val3: List[str] = ['hello'] +MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Sequence[int]") +[out] + +[CASE if_no_explicit_primary_key_id_can_be_passed] +from django.db import models + +class MyModel(models.Model): + pass +MyModel(id=1) +[out] \ No newline at end of file