From e9f9202ed1b4fc5ffc50434d9047ec184c939c88 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sun, 17 Feb 2019 18:07:53 +0300 Subject: [PATCH] preliminary support for strict_optional --- django-stubs/contrib/admin/models.pyi | 2 +- django-stubs/contrib/auth/models.pyi | 8 +- django-stubs/contrib/contenttypes/fields.pyi | 2 +- .../contrib/postgres/fields/array.pyi | 45 +++- django-stubs/db/models/fields/__init__.pyi | 138 ++++++------ django-stubs/db/models/fields/related.pyi | 117 ++++++++-- mypy_django_plugin/helpers.py | 77 ++++--- mypy_django_plugin/main.py | 59 +++--- mypy_django_plugin/plugins/fields.py | 74 ------- mypy_django_plugin/plugins/related_fields.py | 62 ------ .../{plugins => transformers}/__init__.py | 0 mypy_django_plugin/transformers/fields.py | 199 ++++++++++++++++++ .../{plugins => transformers}/init_create.py | 48 +++-- .../{plugins => transformers}/migrations.py | 1 - .../{plugins => transformers}/models.py | 2 - .../{plugins => transformers}/settings.py | 0 scripts/typecheck_tests.py | 2 +- test-data/plugins.ini | 3 +- test-data/typecheck/fields.test | 26 +-- test-data/typecheck/model_create.test | 13 +- test-data/typecheck/model_init.test | 27 +-- test-data/typecheck/nullable_fields.test | 42 ++++ test-data/typecheck/related_fields.test | 10 +- 23 files changed, 614 insertions(+), 343 deletions(-) delete mode 100644 mypy_django_plugin/plugins/fields.py delete mode 100644 mypy_django_plugin/plugins/related_fields.py rename mypy_django_plugin/{plugins => transformers}/__init__.py (100%) create mode 100644 mypy_django_plugin/transformers/fields.py rename mypy_django_plugin/{plugins => transformers}/init_create.py (75%) rename mypy_django_plugin/{plugins => transformers}/migrations.py (99%) rename mypy_django_plugin/{plugins => transformers}/models.py (98%) rename mypy_django_plugin/{plugins => transformers}/settings.py (100%) create mode 100644 test-data/typecheck/nullable_fields.test diff --git a/django-stubs/contrib/admin/models.pyi b/django-stubs/contrib/admin/models.pyi index edfb37d..05c033a 100644 --- a/django-stubs/contrib/admin/models.pyi +++ b/django-stubs/contrib/admin/models.pyi @@ -25,7 +25,7 @@ class LogEntryManager(models.Manager["LogEntry"]): class LogEntry(models.Model): action_time: models.DateTimeField = ... user: models.ForeignKey = ... - content_type: models.ForeignKey[ContentType] = ... + content_type: models.ForeignKey = models.ForeignKey(ContentType, on_delete=models.CASCADE) object_id: models.TextField = ... object_repr: models.CharField = ... action_flag: models.PositiveSmallIntegerField = ... diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index 8147c2c..e539cb2 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -15,7 +15,7 @@ class PermissionManager(models.Manager): class Permission(models.Model): content_type_id: int name: models.CharField = ... - content_type: models.ForeignKey[ContentType] = ... + content_type: models.ForeignKey = models.ForeignKey(ContentType, on_delete=models.CASCADE) codename: models.CharField = ... def natural_key(self) -> Tuple[str, str, str]: ... @@ -24,7 +24,7 @@ class GroupManager(models.Manager): class Group(models.Model): name: models.CharField = ... - permissions: models.ManyToManyField[Permission] = ... + permissions: models.ManyToManyField = models.ManyToManyField(Permission) def natural_key(self): ... class UserManager(BaseUserManager): @@ -37,8 +37,8 @@ class UserManager(BaseUserManager): class PermissionsMixin(models.Model): is_superuser: models.BooleanField = ... - groups: models.ManyToManyField[Group] = ... - user_permissions: models.ManyToManyField[Permission] = ... + groups: models.ManyToManyField = models.ManyToManyField(Group) + user_permissions: models.ManyToManyField = models.ManyToManyField(Permission) def get_group_permissions(self, obj: None = ...) -> Set[str]: ... def get_all_permissions(self, obj: Optional[str] = ...) -> Set[str]: ... def has_perm(self, perm: Union[Tuple[str, Any], str], obj: Optional[str] = ...) -> bool: ... diff --git a/django-stubs/contrib/contenttypes/fields.pyi b/django-stubs/contrib/contenttypes/fields.pyi index 1717d6a..1b80e2e 100644 --- a/django-stubs/contrib/contenttypes/fields.pyi +++ b/django-stubs/contrib/contenttypes/fields.pyi @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Generic from django.contrib.contenttypes.models import ContentType from django.core.checks.messages import Error diff --git a/django-stubs/contrib/postgres/fields/array.pyi b/django-stubs/contrib/postgres/fields/array.pyi index 6fa9488..a69b091 100644 --- a/django-stubs/contrib/postgres/fields/array.pyi +++ b/django-stubs/contrib/postgres/fields/array.pyi @@ -1,20 +1,51 @@ -from typing import Any, Generic, List, Optional, Sequence, TypeVar +from typing import Any, Iterable, List, Optional, Sequence, TypeVar, Union + +from django.db.models.expressions import Combinable +from django.db.models.fields import Field, _ErrorMessagesToOverride, _FieldChoices, _ValidatorCallable -from django.db.models.fields import Field from .mixins import CheckFieldDefaultMixin -_T = TypeVar("_T", bound=Field) +# __set__ value type +_ST = TypeVar("_ST") +# __get__ return type +_GT = TypeVar("_GT") + +class ArrayField(CheckFieldDefaultMixin, Field[_ST, _GT]): + _pyi_private_set_type: Union[Sequence[Any], Combinable] + _pyi_private_get_type: List[Any] -class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]): empty_strings_allowed: bool = ... default_error_messages: Any = ... base_field: Any = ... size: Any = ... default_validators: Any = ... from_db_value: Any = ... - def __init__(self, base_field: _T, size: Optional[int] = ..., **kwargs: Any) -> None: ... + def __init__( + self, + base_field: Field, + size: Optional[int] = ..., + verbose_name: Optional[Union[str, bytes]] = ..., + name: Optional[str] = ..., + primary_key: bool = ..., + max_length: Optional[int] = ..., + unique: bool = ..., + blank: bool = ..., + null: bool = ..., + db_index: bool = ..., + default: Any = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + unique_for_date: Optional[str] = ..., + unique_for_month: Optional[str] = ..., + unique_for_year: Optional[str] = ..., + choices: Optional[_FieldChoices] = ..., + help_text: str = ..., + db_column: Optional[str] = ..., + db_tablespace: Optional[str] = ..., + validators: Iterable[_ValidatorCallable] = ..., + error_messages: Optional[_ErrorMessagesToOverride] = ..., + ) -> None: ... @property def description(self): ... def get_transform(self, name: Any): ... - def __set__(self, instance, value: Sequence[_T]) -> None: ... - def __get__(self, instance, owner) -> List[_T]: ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 1ab2a2d..897ded2 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -1,16 +1,15 @@ -import uuid -from datetime import date, time, datetime, timedelta -from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic import decimal - -from typing_extensions import Literal +import uuid +from datetime import date, datetime, time, timedelta +from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union from django.db.models import Model -from django.db.models.query_utils import RegisterLookupMixin - -from django.db.models.expressions import F, Combinable from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist -from django.forms import Widget, Field as FormField +from django.db.models.expressions import Combinable +from django.db.models.query_utils import RegisterLookupMixin +from django.forms import Field as FormField, Widget +from typing_extensions import Literal + from .mixins import NOT_PROVIDED as NOT_PROVIDED _Choice = Tuple[Any, Any] @@ -20,7 +19,15 @@ _FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]] _ValidatorCallable = Callable[..., None] _ErrorMessagesToOverride = Dict[str, Any] -class Field(RegisterLookupMixin): +# __set__ value type +_ST = TypeVar("_ST") +# __get__ return type +_GT = TypeVar("_GT") + +class Field(RegisterLookupMixin, Generic[_ST, _GT]): + _pyi_private_set_type: Any + _pyi_private_get_type: Any + widget: Widget help_text: str db_table: str @@ -52,7 +59,8 @@ class Field(RegisterLookupMixin): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __get__(self, instance, owner) -> Any: ... + def __set__(self, instance, value: _ST) -> None: ... + def __get__(self, instance, owner) -> _GT: ... def deconstruct(self) -> Any: ... def set_attributes_from_name(self, name: str) -> None: ... def db_type(self, connection: Any) -> str: ... @@ -63,23 +71,25 @@ class Field(RegisterLookupMixin): def contribute_to_class(self, cls: Type[Model], name: str, private_only: bool = ...) -> None: ... def to_python(self, value: Any) -> Any: ... -class IntegerField(Field): - def __set__(self, instance, value: Union[int, Combinable, Literal[""]]) -> None: ... - def __get__(self, instance, owner) -> int: ... +class IntegerField(Field[_ST, _GT]): + _pyi_private_set_type: Union[int, Combinable, Literal[""]] + _pyi_private_get_type: int class PositiveIntegerRelDbTypeMixin: def rel_db_type(self, connection: Any): ... -class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): ... -class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField): ... -class SmallIntegerField(IntegerField): ... -class BigIntegerField(IntegerField): ... +class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... +class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... +class SmallIntegerField(IntegerField[_ST, _GT]): ... +class BigIntegerField(IntegerField[_ST, _GT]): ... -class FloatField(Field): - def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ... - def __get__(self, instance, owner) -> float: ... +class FloatField(Field[_ST, _GT]): + _pyi_private_set_type: Union[float, int, str, Combinable] + _pyi_private_get_type: float -class DecimalField(Field): +class DecimalField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, float, decimal.Decimal, Combinable] + _pyi_private_get_type: decimal.Decimal def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -102,13 +112,14 @@ class DecimalField(Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ... - def __get__(self, instance, owner) -> decimal.Decimal: ... -class AutoField(Field): - def __get__(self, instance, owner) -> int: ... +class AutoField(Field[_ST, _GT]): + _pyi_private_set_type: Union[Combinable, int, str] + _pyi_private_get_type: int -class CharField(Field): +class CharField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, int, Combinable] + _pyi_private_get_type: str def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -133,10 +144,8 @@ class CharField(Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Union[str, int, Combinable]) -> None: ... - def __get__(self, instance, owner) -> str: ... -class SlugField(CharField): +class SlugField(CharField[_ST, _GT]): def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -163,25 +172,29 @@ class SlugField(CharField): error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... -class EmailField(CharField): ... -class URLField(CharField): ... +class EmailField(CharField[_ST, _GT]): ... +class URLField(CharField[_ST, _GT]): ... -class TextField(Field): - def __set__(self, instance, value: Union[str, Combinable]) -> None: ... - def __get__(self, instance, owner) -> str: ... +class TextField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, Combinable] + _pyi_private_get_type: str -class BooleanField(Field): - def __set__(self, instance, value: Union[bool, Combinable]) -> None: ... - def __get__(self, instance, owner) -> bool: ... +class BooleanField(Field[_ST, _GT]): + _pyi_private_set_type: Union[bool, Combinable] + _pyi_private_get_type: bool -class NullBooleanField(Field): - def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ... - def __get__(self, instance, owner) -> Optional[bool]: ... +class NullBooleanField(Field[_ST, _GT]): + _pyi_private_set_type: Optional[Union[bool, Combinable]] + _pyi_private_get_type: Optional[bool] -class IPAddressField(Field): - def __get__(self, instance, owner) -> str: ... +class IPAddressField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, Combinable] + _pyi_private_get_type: str + +class GenericIPAddressField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, int, Callable[..., Any], Combinable] + _pyi_private_get_type: str -class GenericIPAddressField(Field): default_error_messages: Any = ... unpack_ipv4: Any = ... protocol: Any = ... @@ -207,12 +220,12 @@ class GenericIPAddressField(Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ) -> None: ... - def __set__(self, instance, value: Union[str, int, Callable[..., Any], Combinable]): ... - def __get__(self, instance, owner) -> str: ... class DateTimeCheckMixin: ... -class DateField(DateTimeCheckMixin, Field): +class DateField(DateTimeCheckMixin, Field[_ST, _GT]): + _pyi_private_set_type: Union[str, date, datetime, Combinable] + _pyi_private_get_type: date def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -236,10 +249,10 @@ class DateField(DateTimeCheckMixin, Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Union[str, date, Combinable]) -> None: ... - def __get__(self, instance, owner) -> date: ... -class TimeField(DateTimeCheckMixin, Field): +class TimeField(DateTimeCheckMixin, Field[_ST, _GT]): + _pyi_private_set_type: Union[str, time, datetime, Combinable] + _pyi_private_get_type: time def __init__( self, verbose_name: Optional[Union[str, bytes]] = ..., @@ -262,18 +275,15 @@ class TimeField(DateTimeCheckMixin, Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Union[str, time, datetime, Combinable]) -> None: ... - def __get__(self, instance, owner) -> time: ... -class DateTimeField(DateField): - def __set__(self, instance, value: Union[str, date, datetime, Combinable]) -> None: ... - def __get__(self, instance, owner) -> datetime: ... +class DateTimeField(DateField[_ST, _GT]): + _pyi_private_get_type: datetime -class UUIDField(Field): - def __set__(self, instance, value: Union[str, uuid.UUID]) -> None: ... - def __get__(self, instance, owner) -> uuid.UUID: ... +class UUIDField(Field[_ST, _GT]): + _pyi_private_set_type: Union[str, uuid.UUID] + _pyi_private_get_type: uuid.UUID -class FilePathField(Field): +class FilePathField(Field[_ST, _GT]): path: str = ... match: Optional[Any] = ... recursive: bool = ... @@ -306,10 +316,10 @@ class FilePathField(Field): error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... -class BinaryField(Field): ... +class BinaryField(Field[_ST, _GT]): ... -class DurationField(Field): - def __get__(self, instance, owner) -> timedelta: ... +class DurationField(Field[_ST, _GT]): + _pyi_private_get_type: timedelta -class BigAutoField(AutoField): ... -class CommaSeparatedIntegerField(CharField): ... +class BigAutoField(AutoField[_ST, _GT]): ... +class CommaSeparatedIntegerField(CharField[_ST, _GT]): ... diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index e9b8d9f..d990e00 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -49,7 +49,12 @@ _ErrorMessagesToOverride = Dict[str, Any] RECURSIVE_RELATIONSHIP_CONSTANT: str = ... -class RelatedField(FieldCacheMixin, Field): +# __set__ value type +_ST = TypeVar("_ST") +# __get__ return type +_GT = TypeVar("_GT") + +class RelatedField(FieldCacheMixin, Field[_ST, _GT]): one_to_many: bool = ... one_to_one: bool = ... many_to_many: bool = ... @@ -83,6 +88,7 @@ class ForeignObject(RelatedField): related_query_name: None = ..., limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ..., parent_link: bool = ..., + db_constraint: bool = ..., swappable: bool = ..., verbose_name: Optional[str] = ..., name: Optional[str] = ..., @@ -103,17 +109,82 @@ class ForeignObject(RelatedField): error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... -class ForeignKey(RelatedField, Generic[_T]): - def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ... - def __set__(self, instance, value: Union[Model, Combinable]) -> None: ... - def __get__(self, instance, owner) -> _T: ... +class ForeignKey(RelatedField[_ST, _GT]): + _pyi_private_set_type: Union[Any, Combinable] + _pyi_private_get_type: Any + def __init__( + self, + to: Union[Type[Model], str], + on_delete: Callable[..., None], + to_field: Optional[str] = ..., + related_name: str = ..., + related_query_name: Optional[str] = ..., + limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ..., + parent_link: bool = ..., + db_constraint: bool = ..., + verbose_name: Optional[Union[str, bytes]] = ..., + name: Optional[str] = ..., + primary_key: bool = ..., + max_length: Optional[int] = ..., + unique: bool = ..., + blank: bool = ..., + null: bool = ..., + db_index: bool = ..., + default: Any = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + unique_for_date: Optional[str] = ..., + unique_for_month: Optional[str] = ..., + unique_for_year: Optional[str] = ..., + choices: Optional[_FieldChoices] = ..., + help_text: str = ..., + db_column: Optional[str] = ..., + db_tablespace: Optional[str] = ..., + validators: Iterable[_ValidatorCallable] = ..., + error_messages: Optional[_ErrorMessagesToOverride] = ..., + ): ... -class OneToOneField(RelatedField, Generic[_T]): - def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ... - def __set__(self, instance, value: Union[Model, Combinable]) -> None: ... - def __get__(self, instance, owner) -> _T: ... +class OneToOneField(RelatedField[_ST, _GT]): + _pyi_private_set_type: Union[Any, Combinable] + _pyi_private_get_type: Any + def __init__( + self, + to: Union[Type[Model], str], + on_delete: Any, + to_field: Optional[str] = ..., + related_name: str = ..., + related_query_name: Optional[str] = ..., + limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ..., + parent_link: bool = ..., + db_constraint: bool = ..., + verbose_name: Optional[Union[str, bytes]] = ..., + name: Optional[str] = ..., + primary_key: bool = ..., + max_length: Optional[int] = ..., + unique: bool = ..., + blank: bool = ..., + null: bool = ..., + db_index: bool = ..., + default: Any = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + unique_for_date: Optional[str] = ..., + unique_for_month: Optional[str] = ..., + unique_for_year: Optional[str] = ..., + choices: Optional[_FieldChoices] = ..., + help_text: str = ..., + db_column: Optional[str] = ..., + db_tablespace: Optional[str] = ..., + validators: Iterable[_ValidatorCallable] = ..., + error_messages: Optional[_ErrorMessagesToOverride] = ..., + ): ... + +class ManyToManyField(RelatedField[_ST, _GT]): + _pyi_private_set_type: Sequence[Any] + _pyi_private_get_type: RelatedManager[Any] -class ManyToManyField(RelatedField, Generic[_T]): many_to_many: bool = ... many_to_one: bool = ... one_to_many: bool = ... @@ -127,17 +198,35 @@ class ManyToManyField(RelatedField, Generic[_T]): to: Union[Type[_T], str], related_name: Optional[str] = ..., related_query_name: Optional[str] = ..., - limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ..., + limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any], Q]] = ..., symmetrical: Optional[bool] = ..., through: Optional[Union[str, Type[Model]]] = ..., through_fields: Optional[Tuple[str, str]] = ..., db_constraint: bool = ..., db_table: Optional[str] = ..., swappable: bool = ..., - **kwargs: Any + verbose_name: Optional[Union[str, bytes]] = ..., + name: Optional[str] = ..., + primary_key: bool = ..., + max_length: Optional[int] = ..., + unique: bool = ..., + blank: bool = ..., + null: bool = ..., + db_index: bool = ..., + default: Any = ..., + editable: bool = ..., + auto_created: bool = ..., + serialize: bool = ..., + unique_for_date: Optional[str] = ..., + unique_for_month: Optional[str] = ..., + unique_for_year: Optional[str] = ..., + choices: Optional[_FieldChoices] = ..., + help_text: str = ..., + db_column: Optional[str] = ..., + db_tablespace: Optional[str] = ..., + validators: Iterable[_ValidatorCallable] = ..., + error_messages: Optional[_ErrorMessagesToOverride] = ..., ) -> None: ... - def __set__(self, instance, value: Sequence[_T]) -> None: ... - def __get__(self, instance, owner) -> RelatedManager[_T]: ... def check(self, **kwargs: Any) -> List[Any]: ... def deconstruct(self) -> Tuple[Optional[str], str, List[Any], Dict[str, str]]: ... def get_path_info(self, filtered_relation: None = ...) -> List[PathInfo]: ... diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 0c28f32..ff51fcb 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -5,10 +5,12 @@ from mypy.checker import TypeChecker from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \ TypeInfo from mypy.plugin import FunctionContext -from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType +from mypy.types import AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' FIELD_FULLNAME = 'django.db.models.fields.Field' +ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField' +AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField' GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey' FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' @@ -95,12 +97,13 @@ def parse_bool(expr: Expression) -> Optional[bool]: return None -def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): - return Instance(instance.type, args=new_typevars) +def reparametrize_instance(instance: Instance, new_args: typing.List[Type]) -> Instance: + return Instance(instance.type, args=new_args, + line=instance.line, column=instance.column) -def fill_typevars_with_any(instance: Instance) -> Type: - return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)]) +def fill_typevars_with_any(instance: Instance) -> Instance: + return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)]) def extract_typevar_value(tp: Instance, typevar_name: str) -> Type: @@ -117,7 +120,7 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance: for typevar_arg in type_to_fill.args: if isinstance(typevar_arg, TypeVarType): typevar_values.append(extract_typevar_value(tp, typevar_arg.name)) - return reparametrize_with(type_to_fill, typevar_values) + return Instance(type_to_fill.type, typevar_values) def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: @@ -189,28 +192,12 @@ def iter_over_assignments( def extract_field_setter_type(tp: Instance) -> Optional[Type]: - if not isinstance(tp, Instance): - return None + """ Extract __set__ value of a field. """ if tp.type.has_base(FIELD_FULLNAME): - set_method = tp.type.get_method('__set__') - if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): - if 'value' in set_method.type.arg_names: - set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] - if isinstance(set_value_type, Instance): - set_value_type = fill_typevars(tp, set_value_type) - return set_value_type - elif isinstance(set_value_type, UnionType): - items_no_typevars = [] - for item in set_value_type.items: - if isinstance(item, Instance): - item = fill_typevars(tp, item) - items_no_typevars.append(item) - return UnionType(items_no_typevars) - - field_getter_type = extract_field_getter_type(tp) - if field_getter_type: - return field_getter_type - + return tp.args[0] + # GenericForeignKey + if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): + return AnyType(TypeOfAny.special_form) return None @@ -218,9 +205,7 @@ def extract_field_getter_type(tp: Instance) -> Optional[Type]: if not isinstance(tp, Instance): return None if tp.type.has_base(FIELD_FULLNAME): - get_method = tp.type.get_method('__get__') - if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): - return get_method.type.ret_type + return tp.args[1] # GenericForeignKey if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): return AnyType(TypeOfAny.special_form) @@ -240,7 +225,10 @@ def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]: return get_django_metadata(model).setdefault('fields', {}) -def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]: +def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]: + """ + If field with primary_key=True is set on the model, extract its __set__ type. + """ for field_name, props in get_fields_metadata(model).items(): is_primary_key = props.get('primary_key', False) if is_primary_key: @@ -254,3 +242,30 @@ def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]: if is_primary_key: return extract_field_getter_type(model.names[field_name].type) return None + + +def make_optional(typ: Type): + return UnionType.make_simplified_union([typ, NoneTyp()]) + + +def make_required(typ: Type) -> Type: + if not isinstance(typ, UnionType): + return typ + items = [item for item in typ.items if not isinstance(item, NoneTyp)] + # will reduce to Instance, if only one item + return UnionType.make_union(items) + + +def is_optional(typ: Type) -> bool: + if not isinstance(typ, UnionType): + return False + + return any([isinstance(item, NoneTyp) for item in typ.items]) + + + +def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool: + for base_fullname in bases: + if info.has_base(base_fullname): + return True + return False diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index ff40ec6..a2c87a6 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,19 +1,18 @@ import os -from typing import Callable, Dict, Optional, cast +from typing import Callable, Dict, Optional, Union, cast from mypy.checker import TypeChecker from mypy.nodes import MemberExpr, TypeInfo from mypy.options import Options from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin -from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType +from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.config import Config -from mypy_django_plugin.plugins import init_create -from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class -from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr -from mypy_django_plugin.plugins.models import process_model_class -from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with -from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata +from mypy_django_plugin.transformers import fields, init_create +from mypy_django_plugin.transformers.migrations import determine_model_cls_from_string_for_migrations, \ + get_string_value_from_expr +from mypy_django_plugin.transformers.models import process_model_class +from mypy_django_plugin.transformers.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata def transform_model_class(ctx: ClassDefContext) -> None: @@ -50,7 +49,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME, helpers.RELATED_MANAGER_CLASS_FULLNAME, helpers.BASE_MANAGER_CLASS_FULLNAME}: - ret.type.bases[i] = reparametrize_with(base, [Instance(outer_model_info, [])]) + ret.type.bases[i] = Instance(base.type, [Instance(outer_model_info, [])]) return ret return ret @@ -84,6 +83,17 @@ def return_user_model_hook(ctx: FunctionContext) -> Type: return TypeType(Instance(model_info, [])) +def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: + if isinstance(typ, Instance): + return typ.type + else: + # should be Union[TYPE, None] + typ = helpers.make_required(typ) + if isinstance(typ, Instance): + return typ.type + return None + + def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type: if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): return ctx.default_attr_type @@ -94,14 +104,22 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu field_name = ctx.context.name.split('_')[0] sym = ctx.type.type.get(field_name) if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0: - to_arg = sym.type.args[0] - if isinstance(to_arg, AnyType): - return AnyType(TypeOfAny.special_form) + referred_to = sym.type.args[1] + if isinstance(referred_to, AnyType): + return AnyType(TypeOfAny.implementation_artifact) + + model_type = _extract_referred_to_type_info(referred_to) + if model_type is None: + return AnyType(TypeOfAny.implementation_artifact) - model_type: TypeInfo = to_arg.type primary_key_type = helpers.extract_primary_key_type_for_get(model_type) if primary_key_type: return primary_key_type + + is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False) + if is_nullable: + return helpers.make_optional(ctx.default_attr_type) + return ctx.default_attr_type @@ -179,26 +197,19 @@ class DjangoPlugin(Plugin): def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FIELD_FULLNAME): + return fields.adjust_return_type_of_field_instantiation + if fullname == 'django.contrib.auth.get_user_model': return return_user_model_hook - if fullname in {helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME, - helpers.MANYTOMANY_FIELD_FULLNAME}: - return extract_to_parameter_as_get_ret_type_for_related_field - - if fullname == 'django.contrib.postgres.fields.array.ArrayField': - return determine_type_of_array_field - manager_bases = self._get_current_manager_bases() if fullname in manager_bases: return determine_proper_manager_type sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): - if sym.node.has_base(helpers.FIELD_FULLNAME): - return record_field_properties_into_outer_model_class - if sym.node.metadata.get('django', {}).get('generated_init'): return init_create.redefine_and_typecheck_model_init diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py deleted file mode 100644 index 2a3caab..0000000 --- a/mypy_django_plugin/plugins/fields.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import cast - -from mypy.checker import TypeChecker -from mypy.nodes import ListExpr, NameExpr, TupleExpr -from mypy.plugin import FunctionContext -from mypy.types import Instance, TupleType, Type - -from mypy_django_plugin import helpers -from mypy_django_plugin.plugins.models import iter_over_assignments - - -def determine_type_of_array_field(ctx: FunctionContext) -> Type: - base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field') - if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): - return ctx.default_return_type - - get_method = base_field_arg_type.type.get_method('__get__') - if not get_method: - # not a method - return ctx.default_return_type - - return ctx.api.named_generic_type(ctx.context.callee.fullname, - args=[get_method.type.ret_type]) - - -def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type: - api = cast(TypeChecker, ctx.api) - outer_model = api.scope.active_class() - if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME): - # outside models.Model class, undetermined - return ctx.default_return_type - - field_name = None - for name_expr, stmt in iter_over_assignments(outer_model.defn): - if stmt == ctx.context and isinstance(name_expr, NameExpr): - field_name = name_expr.name - break - if field_name is None: - return ctx.default_return_type - - fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {}) - - # primary key - is_primary_key = False - primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key') - if primary_key_arg: - is_primary_key = helpers.parse_bool(primary_key_arg) - fields_metadata[field_name] = {'primary_key': is_primary_key} - - # choices - choices_arg = helpers.get_argument_by_name(ctx, 'choices') - if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)): - # iterable of 2 element tuples of two kinds - _, analyzed_choices = api.analyze_iterable_item_type(choices_arg) - if isinstance(analyzed_choices, TupleType): - first_element_type = analyzed_choices.items[0] - if isinstance(first_element_type, Instance): - fields_metadata[field_name]['choices'] = first_element_type.type.fullname() - - # nullability - null_arg = helpers.get_argument_by_name(ctx, 'null') - is_nullable = False - if null_arg: - is_nullable = helpers.parse_bool(null_arg) - fields_metadata[field_name]['null'] = is_nullable - - # is_blankable - blank_arg = helpers.get_argument_by_name(ctx, 'blank') - is_blankable = False - if blank_arg: - is_blankable = helpers.parse_bool(blank_arg) - fields_metadata[field_name]['blank'] = is_blankable - - return ctx.default_return_type diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py deleted file mode 100644 index 2fd0aac..0000000 --- a/mypy_django_plugin/plugins/related_fields.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Optional, cast - -from mypy.checker import TypeChecker -from mypy.nodes import StrExpr, TypeInfo -from mypy.plugin import FunctionContext -from mypy.types import CallableType, Instance, Type - -from mypy_django_plugin import helpers -from mypy_django_plugin.helpers import fill_typevars_with_any, reparametrize_with - - -def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: - api = cast(TypeChecker, ctx.api) - if 'to' not in ctx.callee_arg_names: - # shouldn't happen, invalid code - api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', - context=ctx.context) - return None - - arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0] - if not isinstance(arg_type, CallableType): - to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0] - if not isinstance(to_arg_expr, StrExpr): - # not string, not supported - return None - try: - model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value, - all_modules=api.modules) - except helpers.SelfReference: - model_fullname = api.tscope.classes[-1].fullname() - - if model_fullname is None: - return None - model_info = helpers.lookup_fully_qualified_generic(model_fullname, - all_modules=api.modules) - if model_info is None or not isinstance(model_info, TypeInfo): - return None - return Instance(model_info, []) - - referred_to_type = arg_type.ret_type - if not isinstance(referred_to_type, Instance): - return None - if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME): - ctx.api.msg.fail(f'to= parameter value must be ' - f'a subclass of {helpers.MODEL_CLASS_FULLNAME}', - context=ctx.context) - return None - - return referred_to_type - - -def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type: - try: - referred_to_type = get_valid_to_value_or_none(ctx) - except helpers.InvalidModelString as exc: - ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context) - return fill_typevars_with_any(ctx.default_return_type) - - if referred_to_type is None: - # couldn't extract to= value - return fill_typevars_with_any(ctx.default_return_type) - return reparametrize_with(ctx.default_return_type, [referred_to_type]) diff --git a/mypy_django_plugin/plugins/__init__.py b/mypy_django_plugin/transformers/__init__.py similarity index 100% rename from mypy_django_plugin/plugins/__init__.py rename to mypy_django_plugin/transformers/__init__.py diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py new file mode 100644 index 0000000..2b2a880 --- /dev/null +++ b/mypy_django_plugin/transformers/fields.py @@ -0,0 +1,199 @@ +from typing import Optional, cast + +from mypy.checker import TypeChecker +from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var +from mypy.plugin import FunctionContext +from mypy.types import AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType +from mypy_django_plugin import helpers +from mypy_django_plugin.transformers.models import iter_over_assignments + + +def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: + api = cast(TypeChecker, ctx.api) + if 'to' not in ctx.callee_arg_names: + api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', + context=ctx.context) + return None + + arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0] + if not isinstance(arg_type, CallableType): + to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0] + if not isinstance(to_arg_expr, StrExpr): + # not string, not supported + return None + try: + model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value, + all_modules=api.modules) + except helpers.SelfReference: + model_fullname = api.tscope.classes[-1].fullname() + + if model_fullname is None: + return None + model_info = helpers.lookup_fully_qualified_generic(model_fullname, + all_modules=api.modules) + if model_info is None or not isinstance(model_info, TypeInfo): + return None + return Instance(model_info, []) + + referred_to_type = arg_type.ret_type + if not isinstance(referred_to_type, Instance): + return None + if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + ctx.api.msg.fail(f'to= parameter value must be ' + f'a subclass of {helpers.MODEL_CLASS_FULLNAME}', + context=ctx.context) + return None + + return referred_to_type + + +def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type: + if isinstance(typ, UnionType): + converted_items = [] + for item in typ.items: + converted_items.append(convert_any_to_type(item, referred_to_type)) + return UnionType.make_simplified_union(converted_items, + line=typ.line, column=typ.column) + if isinstance(typ, Instance): + args = [] + for default_arg in typ.args: + if isinstance(default_arg, AnyType): + args.append(referred_to_type) + else: + args.append(default_arg) + return helpers.reparametrize_instance(typ, args) + + if isinstance(typ, AnyType): + return referred_to_type + + return typ + + +def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]: + try: + referred_to_type = get_valid_to_value_or_none(ctx) + except helpers.InvalidModelString as exc: + ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context) + return None + + return referred_to_type + + +def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type: + default_return_type = set_descriptor_types_for_field(ctx) + referred_to_type = _extract_referred_to_type(ctx) + if referred_to_type is None: + return default_return_type + + # replace Any with referred_to_type + args = [] + for default_arg in default_return_type.args: + args.append(convert_any_to_type(default_arg, referred_to_type)) + + return helpers.reparametrize_instance(ctx.default_return_type, new_args=args) + + +def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type: + node = type_info.get(private_field_name).node + if isinstance(node, Var): + descriptor_type = node.type + if is_nullable: + descriptor_type = helpers.make_optional(descriptor_type) + return descriptor_type + return AnyType(TypeOfAny.unannotated) + + +def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: + default_return_type = cast(Instance, ctx.default_return_type) + is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null')) + + set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type', + is_nullable=is_nullable) + get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type', + is_nullable=is_nullable) + return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) + + +def determine_type_of_array_field(ctx: FunctionContext) -> Type: + default_return_type = set_descriptor_types_for_field(ctx) + + base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field') + if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): + return default_return_type + + base_type = base_field_arg_type.args[1] # extract __get__ type + args = [] + for default_arg in default_return_type.args: + args.append(convert_any_to_type(default_arg, base_type)) + + return helpers.reparametrize_instance(default_return_type, args) + + +def transform_into_proper_return_type(ctx: FunctionContext) -> Type: + default_return_type = ctx.default_return_type + if not isinstance(default_return_type, Instance): + return default_return_type + + if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME, + helpers.ONETOONE_FIELD_FULLNAME, + helpers.MANYTOMANY_FIELD_FULLNAME)): + return fill_descriptor_types_for_related_field(ctx) + + if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME): + return determine_type_of_array_field(ctx) + + return set_descriptor_types_for_field(ctx) + + +def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type: + record_field_properties_into_outer_model_class(ctx) + return transform_into_proper_return_type(ctx) + + +def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None: + api = cast(TypeChecker, ctx.api) + outer_model = api.scope.active_class() + if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME): + # outside models.Model class, undetermined + return + + field_name = None + for name_expr, stmt in iter_over_assignments(outer_model.defn): + if stmt == ctx.context and isinstance(name_expr, NameExpr): + field_name = name_expr.name + break + if field_name is None: + return + + fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {}) + + # primary key + is_primary_key = False + primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key') + if primary_key_arg: + is_primary_key = helpers.parse_bool(primary_key_arg) + fields_metadata[field_name] = {'primary_key': is_primary_key} + + # choices + choices_arg = helpers.get_argument_by_name(ctx, 'choices') + if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)): + # iterable of 2 element tuples of two kinds + _, analyzed_choices = api.analyze_iterable_item_type(choices_arg) + if isinstance(analyzed_choices, TupleType): + first_element_type = analyzed_choices.items[0] + if isinstance(first_element_type, Instance): + fields_metadata[field_name]['choices'] = first_element_type.type.fullname() + + # nullability + null_arg = helpers.get_argument_by_name(ctx, 'null') + is_nullable = False + if null_arg: + is_nullable = helpers.parse_bool(null_arg) + fields_metadata[field_name]['null'] = is_nullable + + # is_blankable + blank_arg = helpers.get_argument_by_name(ctx, 'blank') + is_blankable = False + if blank_arg: + is_blankable = helpers.parse_bool(blank_arg) + fields_metadata[field_name]['blank'] = is_blankable diff --git a/mypy_django_plugin/plugins/init_create.py b/mypy_django_plugin/transformers/init_create.py similarity index 75% rename from mypy_django_plugin/plugins/init_create.py rename to mypy_django_plugin/transformers/init_create.py index d963d7f..5635e8b 100644 --- a/mypy_django_plugin/plugins/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -3,10 +3,10 @@ from typing import Dict, Optional, Set, cast from mypy.checker import TypeChecker from mypy.nodes import TypeInfo, Var from mypy.plugin import FunctionContext, MethodContext -from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType - +from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy_django_plugin import helpers -from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata +from mypy_django_plugin.helpers import extract_field_setter_type, extract_explicit_set_type_of_model_primary_key, get_fields_metadata +from mypy_django_plugin.transformers.fields import get_private_descriptor_type def extract_base_pointer_args(model: TypeInfo) -> Set[str]: @@ -112,40 +112,54 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: - expected_types: Dict[str, Type] = {} + api = cast(TypeChecker, ctx.api) - primary_key_type = extract_primary_key_type_for_set(model) + expected_types: Dict[str, Type] = {} + primary_key_type = extract_explicit_set_type_of_model_primary_key(model) if not primary_key_type: # no explicit primary key, set pk to Any and add id primary_key_type = AnyType(TypeOfAny.special_form) expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) - expected_types['pk'] = primary_key_type + for base in model.mro: + # extract all fields for all models in MRO for name, sym in base.names.items(): # do not redefine special attrs if name in {'_meta', 'pk'}: continue + if isinstance(sym.node, Var): - if sym.node.type is None or isinstance(sym.node.type, AnyType): + typ = sym.node.type + if typ is None or isinstance(typ, AnyType): # types are not ready, fallback to Any expected_types[name] = AnyType(TypeOfAny.from_unimported_type) expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type) - elif isinstance(sym.node.type, Instance): - tp = sym.node.type - field_type = extract_field_setter_type(tp) + elif isinstance(typ, Instance): + field_type = extract_field_setter_type(typ) if field_type is None: continue - if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: - ref_to_model = tp.args[0] - primary_key_type = AnyType(TypeOfAny.special_form) - if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): - typ = extract_primary_key_type_for_set(ref_to_model.type) - if typ: - primary_key_type = typ + if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: + primary_key_type = AnyType(TypeOfAny.implementation_artifact) + # in case it's optional, we need Instance type + referred_to_model = typ.args[1] + is_nullable = helpers.is_optional(referred_to_model) + if is_nullable: + referred_to_model = helpers.make_required(typ.args[1]) + + if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): + pk_type = extract_explicit_set_type_of_model_primary_key(referred_to_model.type) + if not pk_type: + # extract set type of AutoField + autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') + pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type', + is_nullable=is_nullable) + primary_key_type = pk_type + expected_types[name + '_id'] = primary_key_type + if field_type: expected_types[name] = field_type diff --git a/mypy_django_plugin/plugins/migrations.py b/mypy_django_plugin/transformers/migrations.py similarity index 99% rename from mypy_django_plugin/plugins/migrations.py rename to mypy_django_plugin/transformers/migrations.py index b6baad8..7371f67 100644 --- a/mypy_django_plugin/plugins/migrations.py +++ b/mypy_django_plugin/transformers/migrations.py @@ -4,7 +4,6 @@ from mypy.checker import TypeChecker from mypy.nodes import Expression, StrExpr, TypeInfo from mypy.plugin import MethodContext from mypy.types import Instance, Type, TypeType - from mypy_django_plugin import helpers diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/transformers/models.py similarity index 98% rename from mypy_django_plugin/plugins/models.py rename to mypy_django_plugin/transformers/models.py index ff2ac07..8f5ce72 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -73,8 +73,6 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp class SetIdAttrsForRelatedFields(ModelClassInitializer): def run(self) -> None: for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef): - # base_model_info = self.api.named_type('builtins.object').type - # helpers.get_related_field_primary_key_names(base_model_info).append(node_name) node_name = lvalue.name + '_id' self.add_new_node_to_model_class(name=node_name, typ=self.api.builtin_type('builtins.int')) diff --git a/mypy_django_plugin/plugins/settings.py b/mypy_django_plugin/transformers/settings.py similarity index 100% rename from mypy_django_plugin/plugins/settings.py rename to mypy_django_plugin/transformers/settings.py diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index 237e881..9efcbee 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -227,7 +227,7 @@ IGNORED_ERRORS = { ], 'postgres_tests': [ 'Cannot assign multiple types to name', - 'Incompatible types in assignment (expression has type "Type[Field]', + 'Incompatible types in assignment (expression has type "Type[Field[Any, Any]]', 'DummyArrayField', 'DummyJSONField', 'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"', diff --git a/test-data/plugins.ini b/test-data/plugins.ini index b77fc5d..b7b510b 100644 --- a/test-data/plugins.ini +++ b/test-data/plugins.ini @@ -1,4 +1,5 @@ [mypy] -incremental = False +incremental = True +strict_optional = True plugins = mypy_django_plugin.main diff --git a/test-data/typecheck/fields.test b/test-data/typecheck/fields.test index da71746..38f7cbf 100644 --- a/test-data/typecheck/fields.test +++ b/test-data/typecheck/fields.test @@ -6,7 +6,7 @@ class User(models.Model): array = ArrayField(base_field=models.Field()) user = User() -reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]' +reveal_type(user.array) # E: Revealed type is 'builtins.list*[Any]' [CASE array_field_base_field_parsed_into_generic_typevar] from django.db import models @@ -17,8 +17,8 @@ class User(models.Model): members_as_text = ArrayField(base_field=models.CharField(max_length=255)) user = User() -reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]' -reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]' +reveal_type(user.members) # E: Revealed type is 'builtins.list*[builtins.int]' +reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list*[builtins.str]' [CASE test_model_fields_classes_present_as_primitives] from django.db import models @@ -31,13 +31,13 @@ class User(models.Model): text = models.TextField() user = User() -reveal_type(user.id) # E: Revealed type is 'builtins.int' -reveal_type(user.small_int) # E: Revealed type is 'builtins.int' -reveal_type(user.name) # E: Revealed type is 'builtins.str' -reveal_type(user.slug) # E: Revealed type is 'builtins.str' -reveal_type(user.text) # E: Revealed type is 'builtins.str' +reveal_type(user.id) # E: Revealed type is 'builtins.int*' +reveal_type(user.small_int) # E: Revealed type is 'builtins.int*' +reveal_type(user.name) # E: Revealed type is 'builtins.str*' +reveal_type(user.slug) # E: Revealed type is 'builtins.str*' +reveal_type(user.text) # E: Revealed type is 'builtins.str*' -[CASE test_model_field_classes_from_exciting_locations] +[CASE test_model_field_classes_from_existing_locations] from django.db import models from django.contrib.postgres import fields as pg_fields from decimal import Decimal @@ -48,9 +48,9 @@ class Booking(models.Model): some_decimal = models.DecimalField(max_digits=10, decimal_places=5) booking = Booking() -reveal_type(booking.id) # E: Revealed type is 'builtins.int' +reveal_type(booking.id) # E: Revealed type is 'builtins.int*' reveal_type(booking.time_range) # E: Revealed type is 'Any' -reveal_type(booking.some_decimal) # E: Revealed type is 'decimal.Decimal' +reveal_type(booking.some_decimal) # E: Revealed type is 'decimal.Decimal*' [CASE test_add_id_field_if_no_primary_key_defined] from django.db import models @@ -66,7 +66,7 @@ from django.db import models class User(models.Model): my_pk = models.IntegerField(primary_key=True) -reveal_type(User().my_pk) # E: Revealed type is 'builtins.int' +reveal_type(User().my_pk) # E: Revealed type is 'builtins.int*' reveal_type(User().id) # E: Revealed type is 'Any' [out] @@ -102,5 +102,5 @@ class ParentModel(models.Model): pass class MyModel(ParentModel): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) -reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID' +reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID*' [out] \ No newline at end of file diff --git a/test-data/typecheck/model_create.test b/test-data/typecheck/model_create.test index cc6f9ee..d4abbe3 100644 --- a/test-data/typecheck/model_create.test +++ b/test-data/typecheck/model_create.test @@ -31,4 +31,15 @@ class Child1(Parent1, Parent2): value = models.IntegerField() class Child4(Child1): value4 = models.IntegerField() -Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) \ No newline at end of file +Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) + +[CASE create_through_related_manager_of_manytomany_field] +from django.db import models + +class Publication(models.Model): + title = models.CharField(max_length=100) +class Article(models.Model): + publications = models.ManyToManyField(Publication) +article = Article.objects.create() +article.publications.create(title='my title') +[out] \ No newline at end of file diff --git a/test-data/typecheck/model_init.test b/test-data/typecheck/model_init.test index a15ca46..9d8f41f 100644 --- a/test-data/typecheck/model_init.test +++ b/test-data/typecheck/model_init.test @@ -71,14 +71,15 @@ from django.db import models class Publisher(models.Model): pass -class PublisherWithCharPK(models.Model): - id = models.IntegerField(primary_key=True) +class PublisherDatetime(models.Model): + dt_pk = models.DateTimeField(primary_key=True) class Book(models.Model): publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) - publisher_with_char_pk = models.ForeignKey(PublisherWithCharPK, on_delete=models.CASCADE) + publisher_dt = models.ForeignKey(PublisherDatetime, on_delete=models.CASCADE) -Book(publisher_id=1, publisher_with_char_pk_id=1) -Book(publisher_id=1, publisher_with_char_pk_id='hello') # E: Incompatible type for "publisher_with_char_pk_id" of "Book" (got "str", expected "Union[int, Combinable, Literal['']]") +Book(publisher_id=1) +Book(publisher_id=[]) # E: Incompatible type for "publisher_id" of "Book" (got "List[Any]", expected "Union[Combinable, int, str]") +Book(publisher_dt_id=11) # E: Incompatible type for "publisher_dt_id" of "Book" (got "int", expected "Union[str, date, datetime, Combinable]") [out] [CASE setting_value_to_an_array_of_ints] @@ -94,7 +95,7 @@ MyModel(array=array_val) array_val2: List[int] = [1] MyModel(array=array_val2) array_val3: List[str] = ['hello'] -MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Sequence[int]") +MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[str]", expected "Union[Sequence[int], Combinable]") [out] [CASE if_no_explicit_primary_key_id_can_be_passed] @@ -135,20 +136,6 @@ Restaurant(place_ptr=place) Restaurant(place_ptr_id=place.id) [out] -[CASE extract_type_of_init_param_from_set_method] -from typing import Union -from datetime import time - -from django.db import models -class MyField(models.Field): - def __set__(self, instance, value: Union[str, time]) -> None: pass - def __get__(self, instance, owner) -> time: pass -class MyModel(models.Model): - field = MyField() -MyModel(field=time()) -MyModel(field='12:00') -MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]") - [CASE charfield_with_integer_choices] from django.db import models class MyModel(models.Model): diff --git a/test-data/typecheck/nullable_fields.test b/test-data/typecheck/nullable_fields.test new file mode 100644 index 0000000..c7c1776 --- /dev/null +++ b/test-data/typecheck/nullable_fields.test @@ -0,0 +1,42 @@ +[CASE nullable_field_with_strict_optional_true] +from django.db import models +class MyModel(models.Model): + text_nullable = models.CharField(max_length=100, null=True) + text = models.CharField(max_length=100) +reveal_type(MyModel().text) # E: Revealed type is 'builtins.str*' +reveal_type(MyModel().text_nullable) # E: Revealed type is 'Union[builtins.str, None]' +MyModel().text = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[str, int, Combinable]") +MyModel().text_nullable = None +[out] + +[CASE nullable_array_field] +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class MyModel(models.Model): + lst = ArrayField(base_field=models.CharField(max_length=100), null=True) +reveal_type(MyModel().lst) # E: Revealed type is 'Union[builtins.list[builtins.str], None]' +[out] + +[CASE nullable_foreign_key] +from django.db import models + +class Publisher(models.Model): + pass +class Book(models.Model): + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, null=True) +reveal_type(Book().publisher) # E: Revealed type is 'Union[main.Publisher, None]' +Book().publisher = 11 # E: Incompatible types in assignment (expression has type "int", variable has type "Union[Publisher, Combinable, None]") +[out] + +[CASE nullable_self_foreign_key] +from django.db import models +class Inventory(models.Model): + parent = models.ForeignKey('self', on_delete=models.SET_NULL, null=True) +parent = Inventory() +core = Inventory(parent_id=parent.id) +reveal_type(core.parent_id) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(core.parent) # E: Revealed type is 'Union[main.Inventory, None]' +Inventory(parent=None) +Inventory(parent_id=None) +[out] \ No newline at end of file diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index f36d286..b6d265c 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -198,7 +198,7 @@ class App(models.Model): pass class Member(models.Model): apps = models.ManyToManyField(to=App, related_name='members') -reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.App*]' +reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager*[main.App]' reveal_type(App().members) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Member]' [out] @@ -207,7 +207,7 @@ from django.db import models from myapp.models import App class Member(models.Model): apps = models.ManyToManyField(to='myapp.App', related_name='members') -reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.App*]' +reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.RelatedManager*[myapp.models.App]' [file myapp/__init__.py] [file myapp/models.py] @@ -226,8 +226,8 @@ reveal_type(User().parent) # E: Revealed type is 'main.User*' [CASE many_to_many_with_self] from django.db import models class User(models.Model): - friends = models.ManyToManyField('self', on_delete=models.CASCADE) -reveal_type(User().friends) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.User*]' + friends = models.ManyToManyField('self') +reveal_type(User().friends) # E: Revealed type is 'django.db.models.manager.RelatedManager*[main.User]' [out] [CASE recursively_checking_for_base_model_in_to_parameter] @@ -274,4 +274,4 @@ Book2(publisher_id=1) Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") Book2.objects.create(publisher_id=1) Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") -[out] +[out] \ No newline at end of file