preliminary support for strict_optional

This commit is contained in:
Maxim Kurnikov
2019-02-17 18:07:53 +03:00
parent 6763217a80
commit e9f9202ed1
23 changed files with 614 additions and 343 deletions

View File

@@ -25,7 +25,7 @@ class LogEntryManager(models.Manager["LogEntry"]):
class LogEntry(models.Model):
action_time: models.DateTimeField = ...
user: models.ForeignKey = ...
content_type: models.ForeignKey[ContentType] = ...
content_type: models.ForeignKey = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id: models.TextField = ...
object_repr: models.CharField = ...
action_flag: models.PositiveSmallIntegerField = ...

View File

@@ -15,7 +15,7 @@ class PermissionManager(models.Manager):
class Permission(models.Model):
content_type_id: int
name: models.CharField = ...
content_type: models.ForeignKey[ContentType] = ...
content_type: models.ForeignKey = models.ForeignKey(ContentType, on_delete=models.CASCADE)
codename: models.CharField = ...
def natural_key(self) -> Tuple[str, str, str]: ...
@@ -24,7 +24,7 @@ class GroupManager(models.Manager):
class Group(models.Model):
name: models.CharField = ...
permissions: models.ManyToManyField[Permission] = ...
permissions: models.ManyToManyField = models.ManyToManyField(Permission)
def natural_key(self): ...
class UserManager(BaseUserManager):
@@ -37,8 +37,8 @@ class UserManager(BaseUserManager):
class PermissionsMixin(models.Model):
is_superuser: models.BooleanField = ...
groups: models.ManyToManyField[Group] = ...
user_permissions: models.ManyToManyField[Permission] = ...
groups: models.ManyToManyField = models.ManyToManyField(Group)
user_permissions: models.ManyToManyField = models.ManyToManyField(Permission)
def get_group_permissions(self, obj: None = ...) -> Set[str]: ...
def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ...
def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Generic
from django.contrib.contenttypes.models import ContentType
from django.core.checks.messages import Error

View File

@@ -1,20 +1,51 @@
from typing import Any, Generic, List, Optional, Sequence, TypeVar
from typing import Any, Iterable, List, Optional, Sequence, TypeVar, Union
from django.db.models.expressions import Combinable
from django.db.models.fields import Field, _ErrorMessagesToOverride, _FieldChoices, _ValidatorCallable
from django.db.models.fields import Field
from .mixins import CheckFieldDefaultMixin
_T = TypeVar("_T", bound=Field)
# __set__ value type
_ST = TypeVar("_ST")
# __get__ return type
_GT = TypeVar("_GT")
class ArrayField(CheckFieldDefaultMixin, Field[_ST, _GT]):
_pyi_private_set_type: Union[Sequence[Any], Combinable]
_pyi_private_get_type: List[Any]
class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
empty_strings_allowed: bool = ...
default_error_messages: Any = ...
base_field: Any = ...
size: Any = ...
default_validators: Any = ...
from_db_value: Any = ...
def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ...
def __init__(
self,
base_field: Field,
size: Optional[int] = ...,
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: bool = ...,
db_index: bool = ...,
default: Any = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> None: ...
@property
def description(self): ...
def get_transform(self, name: Any): ...
def __set__(self, instance, value: Sequence[_T]) -> None: ...
def __get__(self, instance, owner) -> List[_T]: ...

View File

@@ -1,16 +1,15 @@
import uuid
from datetime import date, time, datetime, timedelta
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic
import decimal
from typing_extensions import Literal
import uuid
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union
from django.db.models import Model
from django.db.models.query_utils import RegisterLookupMixin
from django.db.models.expressions import F, Combinable
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
from django.forms import Widget, Field as FormField
from django.db.models.expressions import Combinable
from django.db.models.query_utils import RegisterLookupMixin
from django.forms import Field as FormField, Widget
from typing_extensions import Literal
from .mixins import NOT_PROVIDED as NOT_PROVIDED
_Choice = Tuple[Any, Any]
@@ -20,7 +19,15 @@ _FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
_ValidatorCallable = Callable[..., None]
_ErrorMessagesToOverride = Dict[str, Any]
class Field(RegisterLookupMixin):
# __set__ value type
_ST = TypeVar("_ST")
# __get__ return type
_GT = TypeVar("_GT")
class Field(RegisterLookupMixin, Generic[_ST, _GT]):
_pyi_private_set_type: Any
_pyi_private_get_type: Any
widget: Widget
help_text: str
db_table: str
@@ -52,7 +59,8 @@ class Field(RegisterLookupMixin):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __get__(self, instance, owner) -> Any: ...
def __set__(self, instance, value: _ST) -> None: ...
def __get__(self, instance, owner) -> _GT: ...
def deconstruct(self) -> Any: ...
def set_attributes_from_name(self, name: str) -> None: ...
def db_type(self, connection: Any) -> str: ...
@@ -63,23 +71,25 @@ class Field(RegisterLookupMixin):
def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ...
def to_python(self, value: Any) -> Any: ...
class IntegerField(Field):
def __set__(self, instance, value: Union[int, Combinable, Literal[""]]) -> None: ...
def __get__(self, instance, owner) -> int: ...
class IntegerField(Field[_ST, _GT]):
_pyi_private_set_type: Union[int, Combinable, Literal[""]]
_pyi_private_get_type: int
class PositiveIntegerRelDbTypeMixin:
def rel_db_type(self, connection: Any): ...
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): ...
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): ...
class SmallIntegerField(IntegerField): ...
class BigIntegerField(IntegerField): ...
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
class SmallIntegerField(IntegerField[_ST, _GT]): ...
class BigIntegerField(IntegerField[_ST, _GT]): ...
class FloatField(Field):
def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ...
def __get__(self, instance, owner) -> float: ...
class FloatField(Field[_ST, _GT]):
_pyi_private_set_type: Union[float, int, str, Combinable]
_pyi_private_get_type: float
class DecimalField(Field):
class DecimalField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, float, decimal.Decimal, Combinable]
_pyi_private_get_type: decimal.Decimal
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -102,13 +112,14 @@ class DecimalField(Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ...
def __get__(self, instance, owner) -> decimal.Decimal: ...
class AutoField(Field):
def __get__(self, instance, owner) -> int: ...
class AutoField(Field[_ST, _GT]):
_pyi_private_set_type: Union[Combinable, int, str]
_pyi_private_get_type: int
class CharField(Field):
class CharField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, int, Combinable]
_pyi_private_get_type: str
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -133,10 +144,8 @@ class CharField(Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __set__(self, instance, value: Union[str, int, Combinable]) -> None: ...
def __get__(self, instance, owner) -> str: ...
class SlugField(CharField):
class SlugField(CharField[_ST, _GT]):
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -163,25 +172,29 @@ class SlugField(CharField):
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
class EmailField(CharField): ...
class URLField(CharField): ...
class EmailField(CharField[_ST, _GT]): ...
class URLField(CharField[_ST, _GT]): ...
class TextField(Field):
def __set__(self, instance, value: Union[str, Combinable]) -> None: ...
def __get__(self, instance, owner) -> str: ...
class TextField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, Combinable]
_pyi_private_get_type: str
class BooleanField(Field):
def __set__(self, instance, value: Union[bool, Combinable]) -> None: ...
def __get__(self, instance, owner) -> bool: ...
class BooleanField(Field[_ST, _GT]):
_pyi_private_set_type: Union[bool, Combinable]
_pyi_private_get_type: bool
class NullBooleanField(Field):
def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ...
def __get__(self, instance, owner) -> Optional[bool]: ...
class NullBooleanField(Field[_ST, _GT]):
_pyi_private_set_type: Optional[Union[bool, Combinable]]
_pyi_private_get_type: Optional[bool]
class IPAddressField(Field):
def __get__(self, instance, owner) -> str: ...
class IPAddressField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, Combinable]
_pyi_private_get_type: str
class GenericIPAddressField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, int, Callable[..., Any], Combinable]
_pyi_private_get_type: str
class GenericIPAddressField(Field):
default_error_messages: Any = ...
unpack_ipv4: Any = ...
protocol: Any = ...
@@ -207,12 +220,12 @@ class GenericIPAddressField(Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> None: ...
def __set__(self, instance, value: Union[str, int, Callable[..., Any], Combinable]): ...
def __get__(self, instance, owner) -> str: ...
class DateTimeCheckMixin: ...
class DateField(DateTimeCheckMixin, Field):
class DateField(DateTimeCheckMixin, Field[_ST, _GT]):
_pyi_private_set_type: Union[str, date, datetime, Combinable]
_pyi_private_get_type: date
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -236,10 +249,10 @@ class DateField(DateTimeCheckMixin, Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __set__(self, instance, value: Union[str, date, Combinable]) -> None: ...
def __get__(self, instance, owner) -> date: ...
class TimeField(DateTimeCheckMixin, Field):
class TimeField(DateTimeCheckMixin, Field[_ST, _GT]):
_pyi_private_set_type: Union[str, time, datetime, Combinable]
_pyi_private_get_type: time
def __init__(
self,
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -262,18 +275,15 @@ class TimeField(DateTimeCheckMixin, Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __set__(self, instance, value: Union[str, time, datetime, Combinable]) -> None: ...
def __get__(self, instance, owner) -> time: ...
class DateTimeField(DateField):
def __set__(self, instance, value: Union[str, date, datetime, Combinable]) -> None: ...
def __get__(self, instance, owner) -> datetime: ...
class DateTimeField(DateField[_ST, _GT]):
_pyi_private_get_type: datetime
class UUIDField(Field):
def __set__(self, instance, value: Union[str, uuid.UUID]) -> None: ...
def __get__(self, instance, owner) -> uuid.UUID: ...
class UUIDField(Field[_ST, _GT]):
_pyi_private_set_type: Union[str, uuid.UUID]
_pyi_private_get_type: uuid.UUID
class FilePathField(Field):
class FilePathField(Field[_ST, _GT]):
path: str = ...
match: Optional[Any] = ...
recursive: bool = ...
@@ -306,10 +316,10 @@ class FilePathField(Field):
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
class BinaryField(Field): ...
class BinaryField(Field[_ST, _GT]): ...
class DurationField(Field):
def __get__(self, instance, owner) -> timedelta: ...
class DurationField(Field[_ST, _GT]):
_pyi_private_get_type: timedelta
class BigAutoField(AutoField): ...
class CommaSeparatedIntegerField(CharField): ...
class BigAutoField(AutoField[_ST, _GT]): ...
class CommaSeparatedIntegerField(CharField[_ST, _GT]): ...

View File

@@ -49,7 +49,12 @@ _ErrorMessagesToOverride = Dict[str, Any]
RECURSIVE_RELATIONSHIP_CONSTANT: str = ...
class RelatedField(FieldCacheMixin, Field):
# __set__ value type
_ST = TypeVar("_ST")
# __get__ return type
_GT = TypeVar("_GT")
class RelatedField(FieldCacheMixin, Field[_ST, _GT]):
one_to_many: bool = ...
one_to_one: bool = ...
many_to_many: bool = ...
@@ -83,6 +88,7 @@ class ForeignObject(RelatedField):
related_query_name: None = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ...,
parent_link: bool = ...,
db_constraint: bool = ...,
swappable: bool = ...,
verbose_name: Optional[str] = ...,
name: Optional[str] = ...,
@@ -103,17 +109,82 @@ class ForeignObject(RelatedField):
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
class ForeignKey(RelatedField, Generic[_T]):
def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ...
def __set__(self, instance, value: Union[Model, Combinable]) -> None: ...
def __get__(self, instance, owner) -> _T: ...
class ForeignKey(RelatedField[_ST, _GT]):
_pyi_private_set_type: Union[Any, Combinable]
_pyi_private_get_type: Any
def __init__(
self,
to: Union[Type[Model], str],
on_delete: Callable[..., None],
to_field: Optional[str] = ...,
related_name: str = ...,
related_query_name: Optional[str] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ...,
parent_link: bool = ...,
db_constraint: bool = ...,
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: bool = ...,
db_index: bool = ...,
default: Any = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
class OneToOneField(RelatedField, Generic[_T]):
def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ...
def __set__(self, instance, value: Union[Model, Combinable]) -> None: ...
def __get__(self, instance, owner) -> _T: ...
class OneToOneField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Union[Any, Combinable]
_pyi_private_get_type: Any
def __init__(
self,
to: Union[Type[Model], str],
on_delete: Any,
to_field: Optional[str] = ...,
related_name: str = ...,
related_query_name: Optional[str] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ...,
parent_link: bool = ...,
db_constraint: bool = ...,
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: bool = ...,
db_index: bool = ...,
default: Any = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
class ManyToManyField(RelatedField[_ST, _GT]):
_pyi_private_set_type: Sequence[Any]
_pyi_private_get_type: RelatedManager[Any]
class ManyToManyField(RelatedField, Generic[_T]):
many_to_many: bool = ...
many_to_one: bool = ...
one_to_many: bool = ...
@@ -127,17 +198,35 @@ class ManyToManyField(RelatedField, Generic[_T]):
to: Union[Type[_T], str],
related_name: Optional[str] = ...,
related_query_name: Optional[str] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ...,
symmetrical: Optional[bool] = ...,
through: Optional[Union[str, Type[Model]]] = ...,
through_fields: Optional[Tuple[str, str]] = ...,
db_constraint: bool = ...,
db_table: Optional[str] = ...,
swappable: bool = ...,
**kwargs: Any
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: bool = ...,
db_index: bool = ...,
default: Any = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> None: ...
def __set__(self, instance, value: Sequence[_T]) -> None: ...
def __get__(self, instance, owner) -> RelatedManager[_T]: ...
def check(self, **kwargs: Any) -> List[Any]: ...
def deconstruct(self) -> Tuple[Optional[str], str, List[Any], Dict[str, str]]: ...
def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ...

View File

@@ -5,10 +5,12 @@ from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
from mypy.types import AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
@@ -95,12 +97,13 @@ def parse_bool(expr: Expression) -> Optional[bool]:
return None
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def reparametrize_instance(instance: Instance, new_args: typing.List[Type]) -> Instance:
return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def fill_typevars_with_any(instance: Instance) -> Instance:
return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
@@ -117,7 +120,7 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
for typevar_arg in type_to_fill.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
return reparametrize_with(type_to_fill, typevar_values)
return Instance(type_to_fill.type, typevar_values)
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
@@ -189,28 +192,12 @@ def iter_over_assignments(
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
""" Extract __set__ value of a field. """
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
field_getter_type = extract_field_getter_type(tp)
if field_getter_type:
return field_getter_type
return tp.args[0]
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
@@ -218,9 +205,7 @@ def extract_field_getter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
return tp.args[1]
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
@@ -240,7 +225,10 @@ def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
"""
If field with primary_key=True is set on the model, extract its __set__ type.
"""
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
@@ -254,3 +242,30 @@ def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None
def make_optional(typ: Type):
return UnionType.make_simplified_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
# will reduce to Instance, if only one item
return UnionType.make_union(items)
def is_optional(typ: Type) -> bool:
if not isinstance(typ, UnionType):
return False
return any([isinstance(item, NoneTyp) for item in typ.items])
def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False

View File

@@ -1,19 +1,18 @@
import os
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Union, cast
from mypy.checker import TypeChecker
from mypy.nodes import MemberExpr, TypeInfo
from mypy.options import Options
from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins import init_create
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata
from mypy_django_plugin.transformers import fields, init_create
from mypy_django_plugin.transformers.migrations import determine_model_cls_from_string_for_migrations, \
get_string_value_from_expr
from mypy_django_plugin.transformers.models import process_model_class
from mypy_django_plugin.transformers.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata
def transform_model_class(ctx: ClassDefContext) -> None:
@@ -50,7 +49,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}:
ret.type.bases[i] = reparametrize_with(base, [Instance(outer_model_info, [])])
ret.type.bases[i] = Instance(base.type, [Instance(outer_model_info, [])])
return ret
return ret
@@ -84,6 +83,17 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, []))
def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]:
if isinstance(typ, Instance):
return typ.type
else:
# should be Union[TYPE, None]
typ = helpers.make_required(typ)
if isinstance(typ, Instance):
return typ.type
return None
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
@@ -94,14 +104,22 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
referred_to = sym.type.args[1]
if isinstance(referred_to, AnyType):
return AnyType(TypeOfAny.implementation_artifact)
model_type = _extract_referred_to_type_info(referred_to)
if model_type is None:
return AnyType(TypeOfAny.implementation_artifact)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False)
if is_nullable:
return helpers.make_optional(ctx.default_attr_type)
return ctx.default_attr_type
@@ -179,26 +197,19 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FIELD_FULLNAME):
return fields.adjust_return_type_of_field_instantiation
if fullname == 'django.contrib.auth.get_user_model':
return return_user_model_hook
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return determine_proper_manager_type
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'):
return init_create.redefine_and_typecheck_model_init

View File

@@ -1,74 +0,0 @@
from typing import cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, TupleExpr
from mypy.plugin import FunctionContext
from mypy.types import Instance, TupleType, Type
from mypy_django_plugin import helpers
from mypy_django_plugin.plugins.models import iter_over_assignments
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return ctx.default_return_type
get_method = base_field_arg_type.type.get_method('__get__')
if not get_method:
# not a method
return ctx.default_return_type
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[get_method.type.ret_type])
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return ctx.default_return_type
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return ctx.default_return_type
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable
return ctx.default_return_type

View File

@@ -1,62 +0,0 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import CallableType, Instance, Type
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
# shouldn't happen, invalid code
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
try:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value,
all_modules=api.modules)
except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname()
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return fill_typevars_with_any(ctx.default_return_type)
if referred_to_type is None:
# couldn't extract to= value
return fill_typevars_with_any(ctx.default_return_type)
return reparametrize_with(ctx.default_return_type, [referred_to_type])

View File

@@ -0,0 +1,199 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers
from mypy_django_plugin.transformers.models import iter_over_assignments
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
try:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value,
all_modules=api.modules)
except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname()
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_simplified_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return None
return referred_to_type
def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
referred_to_type = _extract_referred_to_type(ctx)
if referred_to_type is None:
return default_return_type
# replace Any with referred_to_type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, referred_to_type))
return helpers.reparametrize_instance(ctx.default_return_type, new_args=args)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
if is_nullable:
descriptor_type = helpers.make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
is_nullable=is_nullable)
get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
is_nullable=is_nullable)
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
base_type = base_field_arg_type.args[1] # extract __get__ type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, base_type))
return helpers.reparametrize_instance(default_return_type, args)
def transform_into_proper_return_type(ctx: FunctionContext) -> Type:
default_return_type = ctx.default_return_type
if not isinstance(default_return_type, Instance):
return default_return_type
if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME)):
return fill_descriptor_types_for_related_field(ctx)
if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx)
return set_descriptor_types_for_field(ctx)
def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type:
record_field_properties_into_outer_model_class(ctx)
return transform_into_proper_return_type(ctx)
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable

View File

@@ -3,10 +3,10 @@ from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
from mypy_django_plugin.helpers import extract_field_setter_type, extract_explicit_set_type_of_model_primary_key, get_fields_metadata
from mypy_django_plugin.transformers.fields import get_private_descriptor_type
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
@@ -112,40 +112,54 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
api = cast(TypeChecker, ctx.api)
primary_key_type = extract_primary_key_type_for_set(model)
expected_types: Dict[str, Type] = {}
primary_key_type = extract_explicit_set_type_of_model_primary_key(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
# extract all fields for all models in MRO
for name, sym in base.names.items():
# do not redefine special attrs
if name in {'_meta', 'pk'}:
continue
if isinstance(sym.node, Var):
if sym.node.type is None or isinstance(sym.node.type, AnyType):
typ = sym.node.type
if typ is None or isinstance(typ, AnyType):
# types are not ready, fallback to Any
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
elif isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
elif isinstance(typ, Instance):
field_type = extract_field_setter_type(typ)
if field_type is None:
continue
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
primary_key_type = AnyType(TypeOfAny.special_form)
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
typ = extract_primary_key_type_for_set(ref_to_model.type)
if typ:
primary_key_type = typ
if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
primary_key_type = AnyType(TypeOfAny.implementation_artifact)
# in case it's optional, we need Instance type
referred_to_model = typ.args[1]
is_nullable = helpers.is_optional(referred_to_model)
if is_nullable:
referred_to_model = helpers.make_required(typ.args[1])
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
pk_type = extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
if not pk_type:
# extract set type of AutoField
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
is_nullable=is_nullable)
primary_key_type = pk_type
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type

View File

@@ -4,7 +4,6 @@ from mypy.checker import TypeChecker
from mypy.nodes import Expression, StrExpr, TypeInfo
from mypy.plugin import MethodContext
from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers

View File

@@ -73,8 +73,6 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
# base_model_info = self.api.named_type('builtins.object').type
# helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))

View File

@@ -227,7 +227,7 @@ IGNORED_ERRORS = {
],
'postgres_tests': [
'Cannot assign multiple types to name',
'Incompatible types in assignment (expression has type "Type[Field]',
'Incompatible types in assignment (expression has type "Type[Field[Any, Any]]',
'DummyArrayField',
'DummyJSONField',
'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"',

View File

@@ -1,4 +1,5 @@
[mypy]
incremental = False
incremental = True
strict_optional = True
plugins =
mypy_django_plugin.main

View File

@@ -6,7 +6,7 @@ class User(models.Model):
array = ArrayField(base_field=models.Field())
user = User()
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
reveal_type(user.array) # E: Revealed type is 'builtins.list*[Any]'
[CASE array_field_base_field_parsed_into_generic_typevar]
from django.db import models
@@ -17,8 +17,8 @@ class User(models.Model):
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
user = User()
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
reveal_type(user.members) # E: Revealed type is 'builtins.list*[builtins.int]'
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list*[builtins.str]'
[CASE test_model_fields_classes_present_as_primitives]
from django.db import models
@@ -31,13 +31,13 @@ class User(models.Model):
text = models.TextField()
user = User()
reveal_type(user.id) # E: Revealed type is 'builtins.int'
reveal_type(user.small_int) # E: Revealed type is 'builtins.int'
reveal_type(user.name) # E: Revealed type is 'builtins.str'
reveal_type(user.slug) # E: Revealed type is 'builtins.str'
reveal_type(user.text) # E: Revealed type is 'builtins.str'
reveal_type(user.id) # E: Revealed type is 'builtins.int*'
reveal_type(user.small_int) # E: Revealed type is 'builtins.int*'
reveal_type(user.name) # E: Revealed type is 'builtins.str*'
reveal_type(user.slug) # E: Revealed type is 'builtins.str*'
reveal_type(user.text) # E: Revealed type is 'builtins.str*'
[CASE test_model_field_classes_from_exciting_locations]
[CASE test_model_field_classes_from_existing_locations]
from django.db import models
from django.contrib.postgres import fields as pg_fields
from decimal import Decimal
@@ -48,9 +48,9 @@ class Booking(models.Model):
some_decimal = models.DecimalField(max_digits=10, decimal_places=5)
booking = Booking()
reveal_type(booking.id) # E: Revealed type is 'builtins.int'
reveal_type(booking.id) # E: Revealed type is 'builtins.int*'
reveal_type(booking.time_range) # E: Revealed type is 'Any'
reveal_type(booking.some_decimal) # E: Revealed type is 'decimal.Decimal'
reveal_type(booking.some_decimal) # E: Revealed type is 'decimal.Decimal*'
[CASE test_add_id_field_if_no_primary_key_defined]
from django.db import models
@@ -66,7 +66,7 @@ from django.db import models
class User(models.Model):
my_pk = models.IntegerField(primary_key=True)
reveal_type(User().my_pk) # E: Revealed type is 'builtins.int'
reveal_type(User().my_pk) # E: Revealed type is 'builtins.int*'
reveal_type(User().id) # E: Revealed type is 'Any'
[out]
@@ -102,5 +102,5 @@ class ParentModel(models.Model):
pass
class MyModel(ParentModel):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID'
reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID*'
[out]

View File

@@ -32,3 +32,14 @@ class Child1(Parent1, Parent2):
class Child4(Child1):
value4 = models.IntegerField()
Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
[CASE create_through_related_manager_of_manytomany_field]
from django.db import models
class Publication(models.Model):
title = models.CharField(max_length=100)
class Article(models.Model):
publications = models.ManyToManyField(Publication)
article = Article.objects.create()
article.publications.create(title='my title')
[out]

View File

@@ -71,14 +71,15 @@ from django.db import models
class Publisher(models.Model):
pass
class PublisherWithCharPK(models.Model):
id = models.IntegerField(primary_key=True)
class PublisherDatetime(models.Model):
dt_pk = models.DateTimeField(primary_key=True)
class Book(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE)
publisher_dt = models.ForeignKey(PublisherDatetime, on_delete=models.CASCADE)
Book(publisher_id=1, publisher_with_char_pk_id=1)
Book(publisher_id=1, publisher_with_char_pk_id='hello') # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "str", expected "Union[int, Combinable, Literal['']]")
Book(publisher_id=1)
Book(publisher_id=[]) # E: Incompatible type for "publisher_id" of "Book" (got "List[Any]", expected "Union[Combinable, int, str]")
Book(publisher_dt_id=11) # E: Incompatible type for "publisher_dt_id" of "Book" (got "int", expected "Union[str, date, datetime, Combinable]")
[out]
[CASE setting_value_to_an_array_of_ints]
@@ -94,7 +95,7 @@ MyModel(array=array_val)
array_val2: List[int] = [1]
MyModel(array=array_val2)
array_val3: List[str] = ['hello']
MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Sequence[int]")
MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Union[Sequence[int], Combinable]")
[out]
[CASE if_no_explicit_primary_key_id_can_be_passed]
@@ -135,20 +136,6 @@ Restaurant(place_ptr=place)
Restaurant(place_ptr_id=place.id)
[out]
[CASE extract_type_of_init_param_from_set_method]
from typing import Union
from datetime import time
from django.db import models
class MyField(models.Field):
def __set__(self, instance, value: Union[str, time]) -> None: pass
def __get__(self, instance, owner) -> time: pass
class MyModel(models.Model):
field = MyField()
MyModel(field=time())
MyModel(field='12:00')
MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]")
[CASE charfield_with_integer_choices]
from django.db import models
class MyModel(models.Model):

View File

@@ -0,0 +1,42 @@
[CASE nullable_field_with_strict_optional_true]
from django.db import models
class MyModel(models.Model):
text_nullable = models.CharField(max_length=100, null=True)
text = models.CharField(max_length=100)
reveal_type(MyModel().text) # E: Revealed type is 'builtins.str*'
reveal_type(MyModel().text_nullable) # E: Revealed type is 'Union[builtins.str, None]'
MyModel().text = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[str, int, Combinable]")
MyModel().text_nullable = None
[out]
[CASE nullable_array_field]
from django.db import models
from django.contrib.postgres.fields import ArrayField
class MyModel(models.Model):
lst = ArrayField(base_field=models.CharField(max_length=100), null=True)
reveal_type(MyModel().lst) # E: Revealed type is 'Union[builtins.list[builtins.str], None]'
[out]
[CASE nullable_foreign_key]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, null=True)
reveal_type(Book().publisher) # E: Revealed type is 'Union[main.Publisher, None]'
Book().publisher = 11 # E: Incompatible types in assignment (expression has type "int", variable has type "Union[Publisher, Combinable, None]")
[out]
[CASE nullable_self_foreign_key]
from django.db import models
class Inventory(models.Model):
parent = models.ForeignKey('self', on_delete=models.SET_NULL, null=True)
parent = Inventory()
core = Inventory(parent_id=parent.id)
reveal_type(core.parent_id) # E: Revealed type is 'Union[builtins.int, None]'
reveal_type(core.parent) # E: Revealed type is 'Union[main.Inventory, None]'
Inventory(parent=None)
Inventory(parent_id=None)
[out]

View File

@@ -198,7 +198,7 @@ class App(models.Model):
pass
class Member(models.Model):
apps = models.ManyToManyField(to=App, related_name='members')
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.App*]'
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager*[main.App]'
reveal_type(App().members) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Member]'
[out]
@@ -207,7 +207,7 @@ from django.db import models
from myapp.models import App
class Member(models.Model):
apps = models.ManyToManyField(to='myapp.App', related_name='members')
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.App*]'
reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager*[myapp.models.App]'
[file myapp/__init__.py]
[file myapp/models.py]
@@ -226,8 +226,8 @@ reveal_type(User().parent) # E: Revealed type is 'main.User*'
[CASE many_to_many_with_self]
from django.db import models
class User(models.Model):
friends = models.ManyToManyField('self', on_delete=models.CASCADE)
reveal_type(User().friends) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.User*]'
friends = models.ManyToManyField('self')
reveal_type(User().friends) # E: Revealed type is 'django.db.models.manager.RelatedManager*[main.User]'
[out]
[CASE recursively_checking_for_base_model_in_to_parameter]