mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 12:14:28 +08:00
add Model.__init__ typechecking
This commit is contained in:
@@ -1,20 +1,17 @@
|
||||
import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db.models.base import Model
|
||||
|
||||
from django.db import models
|
||||
|
||||
ADDITION: int
|
||||
CHANGE: int
|
||||
DELETION: int
|
||||
ACTION_FLAG_CHOICES: Any
|
||||
|
||||
class LogEntryManager(models.Manager):
|
||||
creation_counter: int
|
||||
model: None
|
||||
name: None
|
||||
use_in_migrations: bool = ...
|
||||
def log_action(
|
||||
self,
|
||||
user_id: int,
|
||||
@@ -22,23 +19,18 @@ class LogEntryManager(models.Manager):
|
||||
object_id: Union[int, str, UUID],
|
||||
object_repr: str,
|
||||
action_flag: int,
|
||||
change_message: Union[
|
||||
Dict[str, Dict[str, List[str]]], List[Dict[str, Dict[str, Union[List[str], str]]]], str
|
||||
] = ...,
|
||||
change_message: Any = ...,
|
||||
) -> LogEntry: ...
|
||||
|
||||
class LogEntry(models.Model):
|
||||
content_type_id: int
|
||||
id: None
|
||||
user_id: int
|
||||
action_time: datetime.datetime = ...
|
||||
user: Any = ...
|
||||
content_type: Any = ...
|
||||
object_id: str = ...
|
||||
object_repr: str = ...
|
||||
action_flag: int = ...
|
||||
change_message: str = ...
|
||||
objects: Any = ...
|
||||
action_time: models.DateTimeField = ...
|
||||
user: models.ForeignKey = ...
|
||||
content_type: models.ForeignKey[ContentType] = ...
|
||||
object_id: models.TextField = ...
|
||||
object_repr: models.CharField = ...
|
||||
action_flag: models.PositiveSmallIntegerField = ...
|
||||
change_message: models.TextField = ...
|
||||
objects: LogEntryManager = ...
|
||||
def is_addition(self) -> bool: ...
|
||||
def is_change(self) -> bool: ...
|
||||
def is_deletion(self) -> bool: ...
|
||||
|
||||
@@ -1,46 +1,33 @@
|
||||
import datetime
|
||||
from typing import Any, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser, BaseUserManager as BaseUserManager
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db.models.manager import EmptyManager
|
||||
|
||||
from django.contrib.auth.validators import UnicodeUsernameValidator
|
||||
from django.db import models
|
||||
|
||||
def update_last_login(sender: Type[AbstractBaseUser], user: AbstractBaseUser, **kwargs: Any) -> None: ...
|
||||
|
||||
class PermissionManager(models.Manager):
|
||||
creation_counter: int
|
||||
model: None
|
||||
name: None
|
||||
use_in_migrations: bool = ...
|
||||
def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ...
|
||||
|
||||
class Permission(models.Model):
|
||||
content_type_id: int
|
||||
id: int
|
||||
name: str = ...
|
||||
content_type: Any = ...
|
||||
name: models.CharField = ...
|
||||
content_type: models.ForeignKey[ContentType] = ...
|
||||
codename: str = ...
|
||||
def natural_key(self) -> Tuple[str, str, str]: ...
|
||||
|
||||
class GroupManager(models.Manager):
|
||||
creation_counter: int
|
||||
model: None
|
||||
name: None
|
||||
use_in_migrations: bool = ...
|
||||
def get_by_natural_key(self, name: str) -> Group: ...
|
||||
|
||||
class Group(models.Model):
|
||||
id: None
|
||||
name: str = ...
|
||||
permissions: Any = ...
|
||||
name: models.CharField = ...
|
||||
permissions: models.ManyToManyField[Permission] = ...
|
||||
def natural_key(self): ...
|
||||
|
||||
class UserManager(BaseUserManager):
|
||||
creation_counter: int
|
||||
model: None
|
||||
name: None
|
||||
use_in_migrations: bool = ...
|
||||
def create_user(
|
||||
self, username: str, email: Optional[str] = ..., password: Optional[str] = ..., **extra_fields: Any
|
||||
) -> AbstractUser: ...
|
||||
@@ -49,9 +36,9 @@ class UserManager(BaseUserManager):
|
||||
) -> AbstractBaseUser: ...
|
||||
|
||||
class PermissionsMixin(models.Model):
|
||||
is_superuser: Any = ...
|
||||
groups: Any = ...
|
||||
user_permissions: Any = ...
|
||||
is_superuser: models.BooleanField = ...
|
||||
groups: models.ManyToManyField[Group] = ...
|
||||
user_permissions: models.ManyToManyField[Permission] = ...
|
||||
def get_group_permissions(self, obj: None = ...) -> Set[str]: ...
|
||||
def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ...
|
||||
def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ...
|
||||
@@ -59,14 +46,13 @@ class PermissionsMixin(models.Model):
|
||||
def has_module_perms(self, app_label: str) -> bool: ...
|
||||
|
||||
class AbstractUser(AbstractBaseUser, PermissionsMixin): # type: ignore
|
||||
is_superuser: bool
|
||||
username_validator: Any = ...
|
||||
username: str = ...
|
||||
first_name: str = ...
|
||||
last_name: str = ...
|
||||
email: str = ...
|
||||
is_staff: bool = ...
|
||||
date_joined: datetime.datetime = ...
|
||||
username_validator: UnicodeUsernameValidator = ...
|
||||
username: models.CharField = ...
|
||||
first_name: models.CharField = ...
|
||||
last_name: models.CharField = ...
|
||||
email: models.EmailField = ...
|
||||
is_staff: models.BooleanField = ...
|
||||
date_joined: models.DateTimeField = ...
|
||||
EMAIL_FIELD: str = ...
|
||||
USERNAME_FIELD: str = ...
|
||||
def clean(self) -> None: ...
|
||||
|
||||
@@ -8,3 +8,4 @@ from .ranges import (
|
||||
DateRangeField as DateRangeField,
|
||||
DateTimeRangeField as DateTimeRangeField,
|
||||
)
|
||||
from .hstore import HStoreField as HStoreField
|
||||
|
||||
@@ -13,14 +13,8 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
|
||||
default_validators: Any = ...
|
||||
from_db_value: Any = ...
|
||||
def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ...
|
||||
def check(self, **kwargs: Any) -> List[Any]: ...
|
||||
@property
|
||||
def description(self): ...
|
||||
def get_db_prep_value(self, value: Any, connection: Any, prepared: bool = ...): ...
|
||||
def to_python(self, value: Any): ...
|
||||
def value_to_string(self, obj: Any): ...
|
||||
def get_transform(self, name: Any): ...
|
||||
def validate(self, value: Any, model_instance: Any) -> None: ...
|
||||
def run_validators(self, value: Any) -> None: ...
|
||||
def __set__(self, instance, value: Sequence[_T]): ...
|
||||
def __set__(self, instance, value: Sequence[_T]) -> None: ...
|
||||
def __get__(self, instance, owner) -> List[_T]: ...
|
||||
|
||||
17
django-stubs/contrib/postgres/fields/hstore.pyi
Normal file
17
django-stubs/contrib/postgres/fields/hstore.pyi
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Field, Transform
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
class HStoreField(CheckFieldDefaultMixin, Field):
|
||||
def get_transform(self, name) -> Any: ...
|
||||
|
||||
class KeyTransform(Transform):
|
||||
def __init__(self, key_name: str, *args: Any, **kwargs: Any): ...
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name: str): ...
|
||||
def __call__(self, *args, **kwargs) -> KeyTransform: ...
|
||||
|
||||
class KeysTransform(Transform): ...
|
||||
class ValuesTransform(Transform): ...
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from datetime import date, time, datetime, timedelta
|
||||
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type
|
||||
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar
|
||||
import decimal
|
||||
|
||||
from django.db.models import Model
|
||||
|
||||
@@ -95,13 +95,14 @@ class QuerySet(Iterable[_T], Sized):
|
||||
def raw(
|
||||
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
|
||||
) -> RawQuerySet: ...
|
||||
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> ValuesIterable: ...
|
||||
@overload
|
||||
def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
|
||||
@overload
|
||||
def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
|
||||
@overload
|
||||
def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
|
||||
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ...
|
||||
def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ...
|
||||
# @overload
|
||||
# def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
|
||||
# @overload
|
||||
# def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
|
||||
# @overload
|
||||
# def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
|
||||
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
|
||||
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
|
||||
def none(self) -> QuerySet[_T]: ...
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import time
|
||||
from decimal import Decimal
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union, Iterable, Sequence
|
||||
|
||||
from django.contrib.admin.options import BaseModelAdmin
|
||||
from django.core.files.base import File
|
||||
@@ -114,6 +114,8 @@ class CheckboxInput(Input):
|
||||
check_test: Callable = ...
|
||||
def __init__(self, attrs: Optional[Dict[str, str]] = ..., check_test: Optional[Callable] = ...) -> None: ...
|
||||
|
||||
_OptAttrs = Dict[str, Any]
|
||||
|
||||
class ChoiceWidget(Widget):
|
||||
allow_multiple_selected: bool = ...
|
||||
input_type: Optional[str] = ...
|
||||
@@ -123,17 +125,9 @@ class ChoiceWidget(Widget):
|
||||
checked_attribute: Any = ...
|
||||
option_inherits_attrs: bool = ...
|
||||
choices: List[List[Union[int, str]]] = ...
|
||||
def __init__(
|
||||
self,
|
||||
attrs: Optional[Dict[str, Union[bool, str]]] = ...,
|
||||
choices: Union[
|
||||
Iterator[Any], List[List[Union[int, str]]], List[Tuple[Union[time, int], int]], List[int], Tuple
|
||||
] = ...,
|
||||
) -> None: ...
|
||||
def options(self, name: str, value: List[str], attrs: Dict[str, Union[bool, str]] = ...) -> None: ...
|
||||
def optgroups(
|
||||
self, name: str, value: List[str], attrs: Optional[Dict[str, Union[bool, str]]] = ...
|
||||
) -> List[Tuple[Optional[str], List[Dict[str, Union[Dict[str, Union[bool, str]], time, int, str]]], int]]: ...
|
||||
def __init__(self, attrs: Optional[_OptAttrs] = ..., choices: Sequence[Tuple[Any, Any]] = ...) -> None: ...
|
||||
def options(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> None: ...
|
||||
def optgroups(self, name: str, value: List[str], attrs: Optional[_OptAttrs] = ...) -> Any: ...
|
||||
def create_option(
|
||||
self,
|
||||
name: str,
|
||||
@@ -142,8 +136,8 @@ class ChoiceWidget(Widget):
|
||||
selected: Union[Set[str], bool],
|
||||
index: int,
|
||||
subindex: Optional[int] = ...,
|
||||
attrs: Optional[Dict[str, Union[bool, str]]] = ...,
|
||||
) -> Dict[str, Union[Dict[str, Union[bool, str]], Dict[str, bool], Set[str], time, int, str]]: ...
|
||||
attrs: Optional[_OptAttrs] = ...,
|
||||
) -> Dict[str, Any]: ...
|
||||
def id_for_label(self, id_: str, index: str = ...) -> str: ...
|
||||
|
||||
class Select(ChoiceWidget):
|
||||
@@ -171,11 +165,7 @@ class CheckboxSelectMultiple(ChoiceWidget):
|
||||
class MultiWidget(Widget):
|
||||
template_name: str = ...
|
||||
widgets: List[Widget] = ...
|
||||
def __init__(
|
||||
self,
|
||||
widgets: Union[List[Type[DateTimeBaseInput]], Tuple[Union[Type[TextInput], Input]]],
|
||||
attrs: Optional[Dict[str, str]] = ...,
|
||||
) -> None: ...
|
||||
def __init__(self, widgets: Sequence[Widget], attrs: Optional[_OptAttrs] = ...) -> None: ...
|
||||
@property
|
||||
def is_hidden(self) -> bool: ...
|
||||
def decompress(self, value: Any) -> Optional[Any]: ...
|
||||
@@ -218,7 +208,7 @@ class SelectDateWidget(Widget):
|
||||
day_none_value: Any = ...
|
||||
def __init__(
|
||||
self,
|
||||
attrs: None = ...,
|
||||
attrs: Optional[_OptAttrs] = ...,
|
||||
years: Optional[Union[Tuple[Union[int, str]], range]] = ...,
|
||||
months: None = ...,
|
||||
empty_label: Optional[Union[Tuple[str, str], str]] = ...,
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.nodes import MypyFile, TypeInfo, ImportedName, SymbolNode
|
||||
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
|
||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
|
||||
@@ -78,3 +82,116 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
|
||||
if sym is None:
|
||||
return None
|
||||
return sym.node
|
||||
|
||||
|
||||
def parse_bool(expr: Expression) -> Optional[bool]:
|
||||
if isinstance(expr, NameExpr):
|
||||
if expr.fullname == 'builtins.True':
|
||||
return True
|
||||
if expr.fullname == 'builtins.False':
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
|
||||
return Instance(instance.type, args=new_typevars)
|
||||
|
||||
|
||||
def fill_typevars_with_any(instance: Instance) -> Type:
|
||||
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
|
||||
|
||||
|
||||
def extract_typevar_value(tp: Instance, typevar_name: str):
|
||||
return tp.args[tp.type.type_vars.index(typevar_name)]
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if tp.type.has_base(FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
typevar_values: typing.List[Type] = []
|
||||
for typevar_arg in set_value_type.args:
|
||||
if isinstance(typevar_arg, TypeVarType):
|
||||
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
|
||||
# if there are typevars, extract from
|
||||
set_value_type = reparametrize_with(set_value_type, typevar_values)
|
||||
return set_value_type
|
||||
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
||||
# only primary keys defined in current class for now
|
||||
for sym in model.names.values():
|
||||
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
||||
tp = sym.node.type
|
||||
if tp.type.metadata.get('django', {}).get('defined_as_primary_key'):
|
||||
field_type = extract_field_setter_type(tp)
|
||||
return field_type
|
||||
return None
|
||||
|
||||
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
expected_types: Dict[str, Type] = {}
|
||||
for base in model.mro:
|
||||
for name, sym in base.names.items():
|
||||
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
||||
tp = sym.node.type
|
||||
field_type = extract_field_setter_type(tp)
|
||||
if tp.type.fullname() == FOREIGN_KEY_FULLNAME:
|
||||
ref_to_model = tp.args[0]
|
||||
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
|
||||
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
||||
if not primary_key_type:
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
if field_type:
|
||||
expected_types[name] = field_type
|
||||
|
||||
primary_key_type = extract_primary_key_type(model)
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
|
||||
expected_types['pk'] = primary_key_type
|
||||
return expected_types
|
||||
|
||||
|
||||
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
|
||||
"""Return the expression for the specific argument.
|
||||
|
||||
This helper should only be used with non-star arguments.
|
||||
"""
|
||||
if name not in ctx.callee_arg_names:
|
||||
return None
|
||||
idx = ctx.callee_arg_names.index(name)
|
||||
args = ctx.args[idx]
|
||||
if len(args) != 1:
|
||||
# Either an error or no value passed.
|
||||
return None
|
||||
return args[0]
|
||||
|
||||
|
||||
def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
|
||||
"""Return the type for the specific argument.
|
||||
|
||||
This helper should only be used with non-star arguments.
|
||||
"""
|
||||
if name not in ctx.callee_arg_names:
|
||||
return None
|
||||
idx = ctx.callee_arg_names.index(name)
|
||||
arg_types = ctx.arg_types[idx]
|
||||
if len(arg_types) != 1:
|
||||
# Either an error or no value passed.
|
||||
return None
|
||||
return arg_types[0]
|
||||
|
||||
@@ -8,6 +8,7 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
||||
from mypy.types import Instance, Type
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.helpers import parse_bool
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
||||
from mypy_django_plugin.plugins.models import process_model_class
|
||||
@@ -54,6 +55,40 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
|
||||
return ret
|
||||
|
||||
|
||||
def redefine_model_init(ctx: FunctionContext) -> Type:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
expected_types = helpers.extract_expected_types(ctx, model)
|
||||
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
|
||||
if actual_name is None:
|
||||
# We can't check kwargs reliably.
|
||||
continue
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
ctx.context)
|
||||
continue
|
||||
api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def set_primary_key_marking(ctx: FunctionContext) -> Type:
|
||||
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
|
||||
if primary_key_arg:
|
||||
is_primary_key = parse_bool(primary_key_arg)
|
||||
if is_primary_key:
|
||||
info = ctx.default_return_type.type
|
||||
info.metadata.setdefault('django', {})['defined_as_primary_key'] = True
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
super().__init__(options)
|
||||
@@ -105,6 +140,13 @@ class DjangoPlugin(Plugin):
|
||||
if fullname in manager_bases:
|
||||
return determine_proper_manager_type
|
||||
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo):
|
||||
if sym.node.has_base(helpers.FIELD_FULLNAME):
|
||||
return set_primary_key_marking
|
||||
if sym.node.metadata.get('django', {}).get('generated_init'):
|
||||
return redefine_model_init
|
||||
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], Type]]:
|
||||
if fullname in {'django.apps.registry.Apps.get_model',
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, Instance
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
if 'base_field' not in ctx.callee_arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0]
|
||||
if not isinstance(base_field_arg_type, Instance):
|
||||
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
|
||||
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
get_method = base_field_arg_type.type.get_method('__get__')
|
||||
|
||||
@@ -2,11 +2,12 @@ from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, Iterator, Optional, Tuple, cast
|
||||
|
||||
import dataclasses
|
||||
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, \
|
||||
StrExpr, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
|
||||
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins.common import add_method
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
@@ -199,16 +200,27 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr,
|
||||
return None
|
||||
|
||||
|
||||
def add_dummy_init_method(ctx: ClassDefContext) -> None:
|
||||
any = AnyType(TypeOfAny.special_form)
|
||||
var = Var('kwargs', any)
|
||||
kw_arg = Argument(variable=var, type_annotation=any, initializer=None, kind=ARG_STAR2)
|
||||
add_method(ctx, '__init__', [kw_arg], NoneTyp())
|
||||
# mark as model class
|
||||
ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext) -> None:
|
||||
initializers = [
|
||||
InjectAnyAsBaseForNestedMeta,
|
||||
AddDefaultObjectsManager,
|
||||
AddIdAttributeIfPrimaryKeyTrueIsNotSet,
|
||||
SetIdAttrsForRelatedFields,
|
||||
AddRelatedManagers
|
||||
AddRelatedManagers,
|
||||
]
|
||||
for initializer_cls in initializers:
|
||||
initializer_cls.from_ctx(ctx).run()
|
||||
|
||||
add_dummy_init_method(ctx)
|
||||
|
||||
# allow unspecified attributes for now
|
||||
ctx.cls.info.fallback_to_any = True
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
import typing
|
||||
from typing import Optional, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import StrExpr, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
|
||||
from mypy.types import CallableType, Instance, Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
|
||||
return Instance(instance.type, args=new_typevars)
|
||||
|
||||
|
||||
def fill_typevars_with_any(instance: Instance) -> Type:
|
||||
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
|
||||
from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with
|
||||
|
||||
|
||||
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
|
||||
|
||||
@@ -29,6 +29,7 @@ IGNORED_ERRORS = {
|
||||
'Invalid value for a to= parameter',
|
||||
'already defined (possibly by an import)',
|
||||
'Cannot assign to a type',
|
||||
re.compile(r'Cannot assign to class variable "[a-z_]+" via instance'),
|
||||
# forms <-> models plugin support
|
||||
'"Model" has no attribute',
|
||||
re.compile(r'Cannot determine type of \'(objects|stuff)\''),
|
||||
@@ -73,7 +74,8 @@ IGNORED_ERRORS = {
|
||||
'admin_views': [
|
||||
'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"',
|
||||
'Incompatible types in assignment',
|
||||
'"object" not callable'
|
||||
'"object" not callable',
|
||||
'Incompatible type for "pk" of "Collector" (got "int", expected "str")'
|
||||
],
|
||||
'aggregation': [
|
||||
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
|
||||
|
||||
106
test-data/typecheck/model_create.test
Normal file
106
test-data/typecheck/model_create.test
Normal file
@@ -0,0 +1,106 @@
|
||||
[CASE arguments_to_init_unexpected_attributes]
|
||||
from django.db import models
|
||||
|
||||
class MyUser(models.Model):
|
||||
pass
|
||||
user = MyUser(name=1, age=12)
|
||||
[out]
|
||||
main:5: error: Unexpected attribute "name" for model "MyUser"
|
||||
main:5: error: Unexpected attribute "age" for model "MyUser"
|
||||
|
||||
[CASE arguments_to_init_from_class_incompatible_type]
|
||||
from django.db import models
|
||||
|
||||
class MyUser(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
age = models.IntegerField()
|
||||
user = MyUser(name=1, age=12)
|
||||
[out]
|
||||
main:6: error: Incompatible type for "name" of "MyUser" (got "int", expected "str")
|
||||
|
||||
[CASE arguments_to_init_combined_from_base_classes]
|
||||
from django.db import models
|
||||
|
||||
class BaseUser(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
age = models.IntegerField()
|
||||
class ChildUser(BaseUser):
|
||||
lastname = models.CharField(max_length=100)
|
||||
user = ChildUser(name='Max', age=12, lastname='Lastname')
|
||||
[out]
|
||||
|
||||
[CASE fields_from_abstract_user_propagate_to_init]
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
|
||||
class MyUser(AbstractUser):
|
||||
pass
|
||||
user = MyUser(username='maxim', password='password', first_name='Max', last_name='MaxMax')
|
||||
[out]
|
||||
|
||||
[CASE generic_foreign_key_field_no_typechecking]
|
||||
from django.db import models
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
|
||||
class MyUser(models.Model):
|
||||
content_object = GenericForeignKey()
|
||||
|
||||
user = MyUser(content_object=1)
|
||||
[out]
|
||||
|
||||
[CASE pk_refers_to_primary_key_and_could_be_passed_to_init]
|
||||
from django.db import models
|
||||
|
||||
class MyUser1(models.Model):
|
||||
mypk = models.CharField(primary_key=True)
|
||||
class MyUser2(models.Model):
|
||||
pass
|
||||
user2 = MyUser1(pk='hello')
|
||||
user3= MyUser2(pk=1)
|
||||
[out]
|
||||
|
||||
[CASE typechecking_of_pk]
|
||||
from django.db import models
|
||||
|
||||
class MyUser1(models.Model):
|
||||
mypk = models.CharField(primary_key=True)
|
||||
user = MyUser1(pk=1) # E: Incompatible type for "pk" of "MyUser1" (got "int", expected "str")
|
||||
[out]
|
||||
|
||||
[CASE can_set_foreign_key_by_its_primary_key]
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
class PublisherWithCharPK(models.Model):
|
||||
id = models.CharField(max_length=100, primary_key=True)
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
|
||||
publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE)
|
||||
|
||||
Book(publisher_id=1, publisher_with_char_pk_id='hello')
|
||||
Book(publisher_id=1, publisher_with_char_pk_id=1) # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "int", expected "str")
|
||||
[out]
|
||||
|
||||
[CASE setting_value_to_an_array_of_ints]
|
||||
from typing import List, Tuple
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class MyModel(models.Model):
|
||||
array = ArrayField(base_field=models.IntegerField())
|
||||
array_val: Tuple[int, ...] = (1,)
|
||||
MyModel(array=array_val)
|
||||
array_val2: List[int] = [1]
|
||||
MyModel(array=array_val2)
|
||||
array_val3: List[str] = ['hello']
|
||||
MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Sequence[int]")
|
||||
[out]
|
||||
|
||||
[CASE if_no_explicit_primary_key_id_can_be_passed]
|
||||
from django.db import models
|
||||
|
||||
class MyModel(models.Model):
|
||||
pass
|
||||
MyModel(id=1)
|
||||
[out]
|
||||
Reference in New Issue
Block a user