add Model.__init__ typechecking

This commit is contained in:
Maxim Kurnikov
2019-02-08 17:16:03 +03:00
parent dead370244
commit 916df1efb6
16 changed files with 359 additions and 108 deletions

View File

@@ -1,20 +1,17 @@
import datetime from typing import Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from uuid import UUID from uuid import UUID
from django.db import models from django.contrib.contenttypes.models import ContentType
from django.db.models.base import Model from django.db.models.base import Model
from django.db import models
ADDITION: int ADDITION: int
CHANGE: int CHANGE: int
DELETION: int DELETION: int
ACTION_FLAG_CHOICES: Any ACTION_FLAG_CHOICES: Any
class LogEntryManager(models.Manager): class LogEntryManager(models.Manager):
creation_counter: int
model: None
name: None
use_in_migrations: bool = ...
def log_action( def log_action(
self, self,
user_id: int, user_id: int,
@@ -22,23 +19,18 @@ class LogEntryManager(models.Manager):
object_id: Union[int, str, UUID], object_id: Union[int, str, UUID],
object_repr: str, object_repr: str,
action_flag: int, action_flag: int,
change_message: Union[ change_message: Any = ...,
Dict[str, Dict[str, List[str]]], List[Dict[str, Dict[str, Union[List[str], str]]]], str
] = ...,
) -> LogEntry: ... ) -> LogEntry: ...
class LogEntry(models.Model): class LogEntry(models.Model):
content_type_id: int action_time: models.DateTimeField = ...
id: None user: models.ForeignKey = ...
user_id: int content_type: models.ForeignKey[ContentType] = ...
action_time: datetime.datetime = ... object_id: models.TextField = ...
user: Any = ... object_repr: models.CharField = ...
content_type: Any = ... action_flag: models.PositiveSmallIntegerField = ...
object_id: str = ... change_message: models.TextField = ...
object_repr: str = ... objects: LogEntryManager = ...
action_flag: int = ...
change_message: str = ...
objects: Any = ...
def is_addition(self) -> bool: ... def is_addition(self) -> bool: ...
def is_change(self) -> bool: ... def is_change(self) -> bool: ...
def is_deletion(self) -> bool: ... def is_deletion(self) -> bool: ...

View File

@@ -1,46 +1,33 @@
import datetime
from typing import Any, List, Optional, Set, Tuple, Type, Union from typing import Any, List, Optional, Set, Tuple, Type, Union
from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager
from django.contrib.contenttypes.models import ContentType
from django.db.models.manager import EmptyManager from django.db.models.manager import EmptyManager
from django.contrib.auth.validators import UnicodeUsernameValidator
from django.db import models from django.db import models
def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ... def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ...
class PermissionManager(models.Manager): class PermissionManager(models.Manager):
creation_counter: int
model: None
name: None
use_in_migrations: bool = ...
def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ... def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ...
class Permission(models.Model): class Permission(models.Model):
content_type_id: int content_type_id: int
id: int name: models.CharField = ...
name: str = ... content_type: models.ForeignKey[ContentType] = ...
content_type: Any = ...
codename: str = ... codename: str = ...
def natural_key(self) -> Tuple[str, str, str]: ... def natural_key(self) -> Tuple[str, str, str]: ...
class GroupManager(models.Manager): class GroupManager(models.Manager):
creation_counter: int
model: None
name: None
use_in_migrations: bool = ...
def get_by_natural_key(self, name: str) -> Group: ... def get_by_natural_key(self, name: str) -> Group: ...
class Group(models.Model): class Group(models.Model):
id: None name: models.CharField = ...
name: str = ... permissions: models.ManyToManyField[Permission] = ...
permissions: Any = ...
def natural_key(self): ... def natural_key(self): ...
class UserManager(BaseUserManager): class UserManager(BaseUserManager):
creation_counter: int
model: None
name: None
use_in_migrations: bool = ...
def create_user( def create_user(
self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any
) -> AbstractUser: ... ) -> AbstractUser: ...
@@ -49,9 +36,9 @@ class UserManager(BaseUserManager):
) -> AbstractBaseUser: ... ) -> AbstractBaseUser: ...
class PermissionsMixin(models.Model): class PermissionsMixin(models.Model):
is_superuser: Any = ... is_superuser: models.BooleanField = ...
groups: Any = ... groups: models.ManyToManyField[Group] = ...
user_permissions: Any = ... user_permissions: models.ManyToManyField[Permission] = ...
def get_group_permissions(self, obj: None = ...) -> Set[str]: ... def get_group_permissions(self, obj: None = ...) -> Set[str]: ...
def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ... def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ...
def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ... def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ...
@@ -59,14 +46,13 @@ class PermissionsMixin(models.Model):
def has_module_perms(self, app_label: str) -> bool: ... def has_module_perms(self, app_label: str) -> bool: ...
class AbstractUser(AbstractBaseUser, PermissionsMixin): # type: ignore class AbstractUser(AbstractBaseUser, PermissionsMixin): # type: ignore
is_superuser: bool username_validator: UnicodeUsernameValidator = ...
username_validator: Any = ... username: models.CharField = ...
username: str = ... first_name: models.CharField = ...
first_name: str = ... last_name: models.CharField = ...
last_name: str = ... email: models.EmailField = ...
email: str = ... is_staff: models.BooleanField = ...
is_staff: bool = ... date_joined: models.DateTimeField = ...
date_joined: datetime.datetime = ...
EMAIL_FIELD: str = ... EMAIL_FIELD: str = ...
USERNAME_FIELD: str = ... USERNAME_FIELD: str = ...
def clean(self) -> None: ... def clean(self) -> None: ...

View File

@@ -8,3 +8,4 @@ from .ranges import (
DateRangeField as DateRangeField, DateRangeField as DateRangeField,
DateTimeRangeField as DateTimeRangeField, DateTimeRangeField as DateTimeRangeField,
) )
from .hstore import HStoreField as HStoreField

View File

@@ -13,14 +13,8 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
default_validators: Any = ... default_validators: Any = ...
from_db_value: Any = ... from_db_value: Any = ...
def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ... def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ...
def check(self, **kwargs: Any) -> List[Any]: ...
@property @property
def description(self): ... def description(self): ...
def get_db_prep_value(self, value: Any, connection: Any, prepared: bool = ...): ...
def to_python(self, value: Any): ...
def value_to_string(self, obj: Any): ...
def get_transform(self, name: Any): ... def get_transform(self, name: Any): ...
def validate(self, value: Any, model_instance: Any) -> None: ... def __set__(self, instance, value: Sequence[_T]) -> None: ...
def run_validators(self, value: Any) -> None: ...
def __set__(self, instance, value: Sequence[_T]): ...
def __get__(self, instance, owner) -> List[_T]: ... def __get__(self, instance, owner) -> List[_T]: ...

View File

@@ -0,0 +1,17 @@
from typing import Any
from django.db.models import Field, Transform
from .mixins import CheckFieldDefaultMixin
class HStoreField(CheckFieldDefaultMixin, Field):
def get_transform(self, name) -> Any: ...
class KeyTransform(Transform):
def __init__(self, key_name: str, *args: Any, **kwargs: Any): ...
class KeyTransformFactory:
def __init__(self, key_name: str): ...
def __call__(self, *args, **kwargs) -> KeyTransform: ...
class KeysTransform(Transform): ...
class ValuesTransform(Transform): ...

View File

@@ -1,6 +1,6 @@
import uuid import uuid
from datetime import date, time, datetime, timedelta from datetime import date, time, datetime, timedelta
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar
import decimal import decimal
from django.db.models import Model from django.db.models import Model

View File

@@ -95,13 +95,14 @@ class QuerySet(Iterable[_T], Sized):
def raw( def raw(
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ... self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
) -> RawQuerySet: ... ) -> RawQuerySet: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesIterable: ... def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ...
@overload def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ...
def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ... # @overload
@overload # def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ... # @overload
@overload # def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ... # @overload
# def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ... def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ... def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ... def none(self) -> QuerySet[_T]: ...

View File

@@ -1,7 +1,7 @@
from datetime import time from datetime import time
from decimal import Decimal from decimal import Decimal
from itertools import chain from itertools import chain
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable, Sequence
from django.contrib.admin.options import BaseModelAdmin from django.contrib.admin.options import BaseModelAdmin
from django.core.files.base import File from django.core.files.base import File
@@ -114,6 +114,8 @@ class CheckboxInput(Input):
check_test: Callable = ... check_test: Callable = ...
def __init__(self, attrs: Optional[Dict[str, str]] = ..., check_test: Optional[Callable] = ...) -> None: ... def __init__(self, attrs: Optional[Dict[str, str]] = ..., check_test: Optional[Callable] = ...) -> None: ...
_OptAttrs = Dict[str, Any]
class ChoiceWidget(Widget): class ChoiceWidget(Widget):
allow_multiple_selected: bool = ... allow_multiple_selected: bool = ...
input_type: Optional[str] = ... input_type: Optional[str] = ...
@@ -123,17 +125,9 @@ class ChoiceWidget(Widget):
checked_attribute: Any = ... checked_attribute: Any = ...
option_inherits_attrs: bool = ... option_inherits_attrs: bool = ...
choices: List[List[Union[int, str]]] = ... choices: List[List[Union[int, str]]] = ...
def __init__( def __init__(self, attrs: Optional[_OptAttrs] = ..., choices: Sequence[Tuple[Any, Any]] = ...) -> None: ...
self, def options(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> None: ...
attrs: Optional[Dict[str, Union[bool, str]]] = ..., def optgroups(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> Any: ...
choices: Union[
Iterator[Any], List[List[Union[int, str]]], List[Tuple[Union[time, int], int]], List[int], Tuple
] = ...,
) -> None: ...
def options(self, name: str, value: List[str], attrs: Dict[str, Union[bool, str]] = ...) -> None: ...
def optgroups(
self, name: str, value: List[str], attrs: Optional[Dict[str, Union[bool, str]]] = ...
) -> List[Tuple[Optional[str], List[Dict[str, Union[Dict[str, Union[bool, str]], time, int, str]]], int]]: ...
def create_option( def create_option(
self, self,
name: str, name: str,
@@ -142,8 +136,8 @@ class ChoiceWidget(Widget):
selected: Union[Set[str], bool], selected: Union[Set[str], bool],
index: int, index: int,
subindex: Optional[int] = ..., subindex: Optional[int] = ...,
attrs: Optional[Dict[str, Union[bool, str]]] = ..., attrs: Optional[_OptAttrs] = ...,
) -> Dict[str, Union[Dict[str, Union[bool, str]], Dict[str, bool], Set[str], time, int, str]]: ... ) -> Dict[str, Any]: ...
def id_for_label(self, id_: str, index: str = ...) -> str: ... def id_for_label(self, id_: str, index: str = ...) -> str: ...
class Select(ChoiceWidget): class Select(ChoiceWidget):
@@ -171,11 +165,7 @@ class CheckboxSelectMultiple(ChoiceWidget):
class MultiWidget(Widget): class MultiWidget(Widget):
template_name: str = ... template_name: str = ...
widgets: List[Widget] = ... widgets: List[Widget] = ...
def __init__( def __init__(self, widgets: Sequence[Widget], attrs: Optional[_OptAttrs] = ...) -> None: ...
self,
widgets: Union[List[Type[DateTimeBaseInput]], Tuple[Union[Type[TextInput], Input]]],
attrs: Optional[Dict[str, str]] = ...,
) -> None: ...
@property @property
def is_hidden(self) -> bool: ... def is_hidden(self) -> bool: ...
def decompress(self, value: Any) -> Optional[Any]: ... def decompress(self, value: Any) -> Optional[Any]: ...
@@ -218,7 +208,7 @@ class SelectDateWidget(Widget):
day_none_value: Any = ... day_none_value: Any = ...
def __init__( def __init__(
self, self,
attrs: None = ..., attrs: Optional[_OptAttrs] = ...,
years: Optional[Union[Tuple[Union[int, str]], range]] = ..., years: Optional[Union[Tuple[Union[int, str]], range]] = ...,
months: None = ..., months: None = ...,
empty_label: Optional[Union[Tuple[str, str], str]] = ..., empty_label: Optional[Union[Tuple[str, str], str]] = ...,

View File

@@ -1,9 +1,13 @@
import typing import typing
from typing import Dict, Optional from typing import Dict, Optional
from mypy.nodes import MypyFile, TypeInfo, ImportedName, SymbolNode from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField' MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
@@ -78,3 +82,116 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
if sym is None: if sym is None:
return None return None
return sym.node return sym.node
def parse_bool(expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == 'builtins.True':
return True
if expr.fullname == 'builtins.False':
return False
return None
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str):
return tp.args[tp.type.type_vars.index(typevar_name)]
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
typevar_values: typing.List[Type] = []
for typevar_arg in set_value_type.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
# if there are typevars, extract from
set_value_type = reparametrize_with(set_value_type, typevar_values)
return set_value_type
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
# only primary keys defined in current class for now
for sym in model.names.values():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
if tp.type.metadata.get('django', {}).get('defined_as_primary_key'):
field_type = extract_field_setter_type(tp)
return field_type
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if tp.type.fullname() == FOREIGN_KEY_FULLNAME:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
primary_key_type = extract_primary_key_type(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
return expected_types
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]

View File

@@ -8,6 +8,7 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.helpers import parse_bool
from mypy_django_plugin.plugins.fields import determine_type_of_array_field from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.models import process_model_class
@@ -54,6 +55,40 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret return ret
def redefine_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = helpers.extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name is None:
# We can't check kwargs reliably.
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def set_primary_key_marking(ctx: FunctionContext) -> Type:
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = parse_bool(primary_key_arg)
if is_primary_key:
info = ctx.default_return_type.type
info.metadata.setdefault('django', {})['defined_as_primary_key'] = True
return ctx.default_return_type
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None: def __init__(self, options: Options) -> None:
super().__init__(options) super().__init__(options)
@@ -105,6 +140,13 @@ class DjangoPlugin(Plugin):
if fullname in manager_bases: if fullname in manager_bases:
return determine_proper_manager_type return determine_proper_manager_type
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return set_primary_key_marking
if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_model_init
def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]: ) -> Optional[Callable[[MethodContext], Type]]:
if fullname in {'django.apps.registry.Apps.get_model', if fullname in {'django.apps.registry.Apps.get_model',

View File

@@ -1,13 +1,12 @@
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import Type, Instance from mypy.types import Type, Instance
from mypy_django_plugin import helpers
def determine_type_of_array_field(ctx: FunctionContext) -> Type: def determine_type_of_array_field(ctx: FunctionContext) -> Type:
if 'base_field' not in ctx.callee_arg_names: base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
return ctx.default_return_type if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0]
if not isinstance(base_field_arg_type, Instance):
return ctx.default_return_type return ctx.default_return_type
get_method = base_field_arg_type.type.get_method('__get__') get_method = base_field_arg_type.type.get_method('__get__')

View File

@@ -2,11 +2,12 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, Optional, Tuple, cast from typing import Dict, Iterator, Optional, Tuple, cast
import dataclasses import dataclasses
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, \ from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
StrExpr, SymbolTableNode, TypeInfo, Var MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
@@ -199,16 +200,27 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
return None return None
def add_dummy_init_method(ctx: ClassDefContext) -> None:
any = AnyType(TypeOfAny.special_form)
var = Var('kwargs', any)
kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2)
add_method(ctx, '__init__', [kw_arg], NoneTyp())
# mark as model class
ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True
def process_model_class(ctx: ClassDefContext) -> None: def process_model_class(ctx: ClassDefContext) -> None:
initializers = [ initializers = [
InjectAnyAsBaseForNestedMeta, InjectAnyAsBaseForNestedMeta,
AddDefaultObjectsManager, AddDefaultObjectsManager,
AddIdAttributeIfPrimaryKeyTrueIsNotSet, AddIdAttributeIfPrimaryKeyTrueIsNotSet,
SetIdAttrsForRelatedFields, SetIdAttrsForRelatedFields,
AddRelatedManagers AddRelatedManagers,
] ]
for initializer_cls in initializers: for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run() initializer_cls.from_ctx(ctx).run()
add_dummy_init_method(ctx)
# allow unspecified attributes for now # allow unspecified attributes for now
ctx.cls.info.fallback_to_any = True ctx.cls.info.fallback_to_any = True

View File

@@ -1,20 +1,12 @@
import typing
from typing import Optional, cast from typing import Optional, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny from mypy.types import CallableType, Instance, Type
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:

View File

@@ -29,6 +29,7 @@ IGNORED_ERRORS = {
'Invalid value for a to= parameter', 'Invalid value for a to= parameter',
'already defined (possibly by an import)', 'already defined (possibly by an import)',
'Cannot assign to a type', 'Cannot assign to a type',
re.compile(r'Cannot assign to class variable "[a-z_]+" via instance'),
# forms <-> models plugin support # forms <-> models plugin support
'"Model" has no attribute', '"Model" has no attribute',
re.compile(r'Cannot determine type of \'(objects|stuff)\''), re.compile(r'Cannot determine type of \'(objects|stuff)\''),
@@ -73,7 +74,8 @@ IGNORED_ERRORS = {
'admin_views': [ 'admin_views': [
'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"', 'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"',
'Incompatible types in assignment', 'Incompatible types in assignment',
'"object" not callable' '"object" not callable',
'Incompatible type for "pk" of "Collector" (got "int", expected "str")'
], ],
'aggregation': [ 'aggregation': [
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")', 'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',

View File

@@ -0,0 +1,106 @@
[CASE arguments_to_init_unexpected_attributes]
from django.db import models
class MyUser(models.Model):
pass
user = MyUser(name=1, age=12)
[out]
main:5: error: Unexpected attribute "name" for model "MyUser"
main:5: error: Unexpected attribute "age" for model "MyUser"
[CASE arguments_to_init_from_class_incompatible_type]
from django.db import models
class MyUser(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()
user = MyUser(name=1, age=12)
[out]
main:6: error: Incompatible type for "name" of "MyUser" (got "int", expected "str")
[CASE arguments_to_init_combined_from_base_classes]
from django.db import models
class BaseUser(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()
class ChildUser(BaseUser):
lastname = models.CharField(max_length=100)
user = ChildUser(name='Max', age=12, lastname='Lastname')
[out]
[CASE fields_from_abstract_user_propagate_to_init]
from django.contrib.auth.models import AbstractUser
class MyUser(AbstractUser):
pass
user = MyUser(username='maxim', password='password', first_name='Max', last_name='MaxMax')
[out]
[CASE generic_foreign_key_field_no_typechecking]
from django.db import models
from django.contrib.contenttypes.fields import GenericForeignKey
class MyUser(models.Model):
content_object = GenericForeignKey()
user = MyUser(content_object=1)
[out]
[CASE pk_refers_to_primary_key_and_could_be_passed_to_init]
from django.db import models
class MyUser1(models.Model):
mypk = models.CharField(primary_key=True)
class MyUser2(models.Model):
pass
user2 = MyUser1(pk='hello')
user3= MyUser2(pk=1)
[out]
[CASE typechecking_of_pk]
from django.db import models
class MyUser1(models.Model):
mypk = models.CharField(primary_key=True)
user = MyUser1(pk=1) # E: Incompatible type for "pk" of "MyUser1" (got "int", expected "str")
[out]
[CASE can_set_foreign_key_by_its_primary_key]
from django.db import models
class Publisher(models.Model):
pass
class PublisherWithCharPK(models.Model):
id = models.CharField(max_length=100, primary_key=True)
class Book(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE)
Book(publisher_id=1, publisher_with_char_pk_id='hello')
Book(publisher_id=1, publisher_with_char_pk_id=1) # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "int", expected "str")
[out]
[CASE setting_value_to_an_array_of_ints]
from typing import List, Tuple
from django.db import models
from django.contrib.postgres.fields import ArrayField
class MyModel(models.Model):
array = ArrayField(base_field=models.IntegerField())
array_val: Tuple[int, ...] = (1,)
MyModel(array=array_val)
array_val2: List[int] = [1]
MyModel(array=array_val2)
array_val3: List[str] = ['hello']
MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Sequence[int]")
[out]
[CASE if_no_explicit_primary_key_id_can_be_passed]
from django.db import models
class MyModel(models.Model):
pass
MyModel(id=1)
[out]