From 9eb95fbab3e1cf9f3d6bb5b94346cacc148b7795 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Tue, 12 Feb 2019 03:54:37 +0300 Subject: [PATCH] add BaseManager.create() typechecking --- django-stubs/contrib/admin/models.pyi | 2 +- django-stubs/contrib/auth/models.pyi | 2 +- django-stubs/contrib/redirects/models.pyi | 10 +- django-stubs/contrib/sessions/backends/db.pyi | 5 +- django-stubs/db/models/base.pyi | 6 +- django-stubs/db/models/fields/__init__.pyi | 14 +- django-stubs/views/debug.pyi | 23 ++- mypy_django_plugin/config.py | 21 ++ mypy_django_plugin/helpers.py | 73 +------ mypy_django_plugin/main.py | 92 ++------- mypy_django_plugin/plugins/fields.py | 58 +++++- mypy_django_plugin/plugins/init_create.py | 182 ++++++++++++++++++ mypy_django_plugin/plugins/migrations.py | 6 +- mypy_django_plugin/plugins/models.py | 92 ++++++--- scripts/typecheck_tests.py | 10 +- test-data/typecheck/managers.test | 139 +++++++++++++ test-data/typecheck/model_create.test | 34 ++++ test-data/typecheck/model_init.test | 6 + test-data/typecheck/models.test | 49 ----- 19 files changed, 557 insertions(+), 267 deletions(-) create mode 100644 mypy_django_plugin/config.py create mode 100644 mypy_django_plugin/plugins/init_create.py create mode 100644 test-data/typecheck/managers.test create mode 100644 test-data/typecheck/model_create.test delete mode 100644 test-data/typecheck/models.test diff --git a/django-stubs/contrib/admin/models.pyi b/django-stubs/contrib/admin/models.pyi index f08d362..edfb37d 100644 --- a/django-stubs/contrib/admin/models.pyi +++ b/django-stubs/contrib/admin/models.pyi @@ -11,7 +11,7 @@ CHANGE: int DELETION: int ACTION_FLAG_CHOICES: Any -class LogEntryManager(models.Manager): +class LogEntryManager(models.Manager["LogEntry"]): def log_action( self, user_id: int, diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index e3ea354..8147c2c 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -16,7 +16,7 @@ class Permission(models.Model): content_type_id: int name: models.CharField = ... content_type: models.ForeignKey[ContentType] = ... - codename: str = ... + codename: models.CharField = ... def natural_key(self) -> Tuple[str, str, str]: ... class GroupManager(models.Manager): diff --git a/django-stubs/contrib/redirects/models.pyi b/django-stubs/contrib/redirects/models.pyi index dd16fcf..b732632 100644 --- a/django-stubs/contrib/redirects/models.pyi +++ b/django-stubs/contrib/redirects/models.pyi @@ -1,10 +1,6 @@ -from typing import Any, Optional - from django.db import models class Redirect(models.Model): - id: None - site_id: int - site: Any = ... - old_path: str = ... - new_path: str = ... + site: models.ForeignKey = ... + old_path: models.CharField = ... + new_path: models.CharField = ... diff --git a/django-stubs/contrib/sessions/backends/db.pyi b/django-stubs/contrib/sessions/backends/db.pyi index 03240fa..f56e83a 100644 --- a/django-stubs/contrib/sessions/backends/db.pyi +++ b/django-stubs/contrib/sessions/backends/db.pyi @@ -1,13 +1,14 @@ -from typing import Any, Dict, Optional, Type, Union +from typing import Dict, Optional, Type, Union from django.contrib.sessions.backends.base import SessionBase from django.contrib.sessions.base_session import AbstractBaseSession from django.contrib.sessions.models import Session +from django.core.signing import Serializer from django.db.models.base import Model class SessionStore(SessionBase): accessed: bool - serializer: Type[django.core.signing.JSONSerializer] + serializer: Type[Serializer] def __init__(self, session_key: Optional[str] = ...) -> None: ... @classmethod def get_model_class(cls) -> Type[Session]: ... diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 4e1ca23..185333c 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence +from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence, Generic from django.db.models.manager import Manager @@ -10,9 +10,9 @@ class Model(metaclass=ModelBase): class DoesNotExist(Exception): ... class Meta: ... _meta: Any + _default_manager: Manager[Model] pk: Any = ... - objects: Manager[Model] - def __init__(self, *args, **kwargs) -> None: ... + def __init__(self: _Self, *args, **kwargs) -> None: ... def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ... def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ... def clean_fields(self, exclude: List[str] = ...) -> None: ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 12428c0..dc51a42 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -1,6 +1,6 @@ import uuid from datetime import date, time, datetime, timedelta -from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar +from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic import decimal from typing_extensions import Literal @@ -14,7 +14,7 @@ from django.forms import Widget, Field as FormField from .mixins import NOT_PROVIDED as NOT_PROVIDED _Choice = Tuple[Any, Any] -_ChoiceNamedGroup = Union[Tuple[str, Iterable[_Choice]], Tuple[str, Any]] +_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]] _FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]] _ValidatorCallable = Callable[..., None] @@ -76,7 +76,7 @@ class SmallIntegerField(IntegerField): ... class BigIntegerField(IntegerField): ... class FloatField(Field): - def __set__(self, instance, value: Union[float, int, Combinable]) -> float: ... + def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ... def __get__(self, instance, owner) -> float: ... class DecimalField(Field): @@ -102,7 +102,7 @@ class DecimalField(Field): validators: Iterable[_ValidatorCallable] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ..., ): ... - def __set__(self, instance, value: Union[str, Combinable]) -> decimal.Decimal: ... + def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ... def __get__(self, instance, owner) -> decimal.Decimal: ... class AutoField(Field): @@ -167,15 +167,15 @@ class EmailField(CharField): ... class URLField(CharField): ... class TextField(Field): - def __set__(self, instance, value: str) -> None: ... + def __set__(self, instance, value: Union[str, Combinable]) -> None: ... def __get__(self, instance, owner) -> str: ... class BooleanField(Field): - def __set__(self, instance, value: bool) -> None: ... + def __set__(self, instance, value: Union[bool, Combinable]) -> None: ... def __get__(self, instance, owner) -> bool: ... class NullBooleanField(Field): - def __set__(self, instance, value: Optional[bool]) -> None: ... + def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ... def __get__(self, instance, owner) -> Optional[bool]: ... class IPAddressField(Field): diff --git a/django-stubs/views/debug.pyi b/django-stubs/views/debug.pyi index e3b8a92..52fd382 100644 --- a/django-stubs/views/debug.pyi +++ b/django-stubs/views/debug.pyi @@ -1,9 +1,8 @@ from importlib.abc import SourceLoader -from typing import Any, Callable, Dict, List, Optional, Type, Union from types import TracebackType +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Type, Union -from django.core.handlers.wsgi import WSGIRequest -from django.http.request import QueryDict +from django.http.request import HttpRequest, QueryDict from django.http.response import Http404, HttpResponse from django.utils.safestring import SafeText @@ -19,21 +18,21 @@ def cleanse_setting(key: Union[int, str], value: Any) -> Any: ... def get_safe_settings() -> Dict[str, Any]: ... def technical_500_response(request: Any, exc_type: Any, exc_value: Any, tb: Any, status_code: int = ...): ... def get_default_exception_reporter_filter() -> ExceptionReporterFilter: ... -def get_exception_reporter_filter(request: Optional[WSGIRequest]) -> ExceptionReporterFilter: ... +def get_exception_reporter_filter(request: Optional[HttpRequest]) -> ExceptionReporterFilter: ... class ExceptionReporterFilter: def get_post_parameters(self, request: Any): ... def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ... class SafeExceptionReporterFilter(ExceptionReporterFilter): - def is_active(self, request: Optional[WSGIRequest]) -> bool: ... - def get_cleansed_multivaluedict(self, request: WSGIRequest, multivaluedict: QueryDict) -> QueryDict: ... - def get_post_parameters(self, request: Optional[WSGIRequest]) -> Union[Dict[Any, Any], QueryDict]: ... - def cleanse_special_types(self, request: Optional[WSGIRequest], value: Any) -> Any: ... + def is_active(self, request: Optional[HttpRequest]) -> bool: ... + def get_cleansed_multivaluedict(self, request: HttpRequest, multivaluedict: QueryDict) -> QueryDict: ... + def get_post_parameters(self, request: Optional[HttpRequest]) -> MutableMapping[str, Any]: ... + def cleanse_special_types(self, request: Optional[HttpRequest], value: Any) -> Any: ... def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ... class ExceptionReporter: - request: Optional[WSGIRequest] = ... + request: Optional[HttpRequest] = ... filter: ExceptionReporterFilter = ... exc_type: None = ... exc_value: Optional[str] = ... @@ -44,7 +43,7 @@ class ExceptionReporter: postmortem: None = ... def __init__( self, - request: Optional[WSGIRequest], + request: Optional[HttpRequest], exc_type: Optional[Type[BaseException]], exc_value: Optional[Union[str, BaseException]], tb: Optional[TracebackType], @@ -63,5 +62,5 @@ class ExceptionReporter: module_name: Optional[str] = None, ): ... -def technical_404_response(request: WSGIRequest, exception: Http404) -> HttpResponse: ... -def default_urlconf(request: WSGIRequest) -> HttpResponse: ... +def technical_404_response(request: HttpRequest, exception: Http404) -> HttpResponse: ... +def default_urlconf(request: HttpRequest) -> HttpResponse: ... diff --git a/mypy_django_plugin/config.py b/mypy_django_plugin/config.py new file mode 100644 index 0000000..34d53f6 --- /dev/null +++ b/mypy_django_plugin/config.py @@ -0,0 +1,21 @@ +from configparser import ConfigParser +from typing import Optional + +from dataclasses import dataclass + + +@dataclass +class Config: + django_settings_module: Optional[str] = None + ignore_missing_settings: bool = False + + @classmethod + def from_config_file(self, fpath: str) -> 'Config': + ini_config = ConfigParser() + ini_config.read(fpath) + if not ini_config.has_section('mypy_django_plugin'): + raise ValueError('Invalid config file: no [mypy_django_plugin] section') + return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings', + fallback=None), + ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings', + fallback=False)) diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 65a0c14..6278bd9 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,10 +1,9 @@ import typing from typing import Dict, Optional -from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \ - CallExpr +from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo from mypy.plugin import FunctionContext -from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType +from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' FIELD_FULLNAME = 'django.db.models.fields.Field' @@ -119,74 +118,6 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance: return reparametrize_with(type_to_fill, typevar_values) -def extract_field_setter_type(tp: Instance) -> Optional[Type]: - if tp.type.has_base(FIELD_FULLNAME): - set_method = tp.type.get_method('__set__') - if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): - if 'value' in set_method.type.arg_names: - set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] - if isinstance(set_value_type, Instance): - set_value_type = fill_typevars(tp, set_value_type) - return set_value_type - elif isinstance(set_value_type, UnionType): - items_no_typevars = [] - for item in set_value_type.items: - if isinstance(item, Instance): - item = fill_typevars(tp, item) - items_no_typevars.append(item) - return UnionType(items_no_typevars) - - get_method = tp.type.get_method('__get__') - if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): - return get_method.type.ret_type - # GenericForeignKey - if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): - return AnyType(TypeOfAny.special_form) - return None - - -def extract_primary_key_type(model: TypeInfo) -> Optional[Type]: - # only primary keys defined in current class for now - for stmt in model.defn.defs.body: - if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr): - name_expr = stmt.lvalues[0] - if isinstance(name_expr, NameExpr): - name = name_expr.name - if 'primary_key' in stmt.rvalue.arg_names: - is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')] - if is_primary_key: - return extract_field_setter_type(model.names[name].type) - return None - - -def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: - expected_types: Dict[str, Type] = {} - - primary_key_type = extract_primary_key_type(model) - if not primary_key_type: - # no explicit primary key, set pk to Any and add id - primary_key_type = AnyType(TypeOfAny.special_form) - expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) - - expected_types['pk'] = primary_key_type - - for base in model.mro: - for name, sym in base.names.items(): - if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): - tp = sym.node.type - field_type = extract_field_setter_type(tp) - if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}: - ref_to_model = tp.args[0] - if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME): - primary_key_type = extract_primary_key_type(ref_to_model.type) - if not primary_key_type: - primary_key_type = AnyType(TypeOfAny.special_form) - expected_types[name + '_id'] = primary_key_type - if field_type: - expected_types[name] = field_type - return expected_types - - def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: """Return the expression for the specific argument. diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index d5224d4..d7b316c 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,8 +1,6 @@ import os -from configparser import ConfigParser -from typing import Callable, Dict, Optional, Set, cast +from typing import Callable, Dict, Optional, cast -from dataclasses import dataclass from mypy.checker import TypeChecker from mypy.nodes import TypeInfo from mypy.options import Options @@ -10,7 +8,9 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.types import Instance, Type from mypy_django_plugin import helpers, monkeypatch -from mypy_django_plugin.plugins.fields import determine_type_of_array_field +from mypy_django_plugin.config import Config +from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class +from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with @@ -56,81 +56,6 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: return ret -def extract_base_pointer_args(model: TypeInfo) -> Set[str]: - pointer_args: Set[str] = set() - for base in model.bases: - if base.type.has_base(helpers.MODEL_CLASS_FULLNAME): - parent_name = base.type.name().lower() - pointer_args.add(f'{parent_name}_ptr') - pointer_args.add(f'{parent_name}_ptr_id') - return pointer_args - - -def redefine_model_init(ctx: FunctionContext) -> Type: - assert isinstance(ctx.default_return_type, Instance) - - api = cast(TypeChecker, ctx.api) - model: TypeInfo = ctx.default_return_type.type - - expected_types = helpers.extract_expected_types(ctx, model) - # order is preserved, can use for positionals - positional_names = list(expected_types.keys()) - positional_names.remove('pk') - visited_positionals = set() - - # check positionals - for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])): - actual_pos_name = positional_names[i] - api.check_subtype(actual_pos_type, expected_types[actual_pos_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_pos_name, - model.name()), - 'got', 'expected') - visited_positionals.add(actual_pos_name) - - # extract name of base models for _ptr - base_pointer_args = extract_base_pointer_args(model) - - # check kwargs - for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])): - if actual_name in base_pointer_args: - # parent_ptr args are not supported - continue - if actual_name in visited_positionals: - continue - if actual_name is None: - # unpacked dict as kwargs is not supported - continue - if actual_name not in expected_types: - ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, - model.name()), - ctx.context) - continue - api.check_subtype(actual_type, expected_types[actual_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_name, - model.name()), - 'got', 'expected') - return ctx.default_return_type - - -@dataclass -class Config: - django_settings_module: Optional[str] = None - ignore_missing_settings: bool = False - - @classmethod - def from_config_file(self, fpath: str) -> 'Config': - ini_config = ConfigParser() - ini_config.read(fpath) - if not ini_config.has_section('mypy_django_plugin'): - raise ValueError('Invalid config file: no [mypy_django_plugin] section') - return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings', - fallback=None), - ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings', - fallback=False)) - - class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -194,11 +119,18 @@ class DjangoPlugin(Plugin): sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): + if sym.node.has_base(helpers.FIELD_FULLNAME): + return record_field_properties_into_outer_model_class if sym.node.metadata.get('django', {}).get('generated_init'): - return redefine_model_init + return redefine_and_typecheck_model_init def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: + manager_classes = self._get_current_manager_bases() + class_fullname, _, method_name = fullname.rpartition('.') + if class_fullname in manager_classes and method_name == 'create': + return redefine_and_typecheck_model_create + if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: return determine_model_cls_from_string_for_migrations diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py index f99baed..2a3caab 100644 --- a/mypy_django_plugin/plugins/fields.py +++ b/mypy_django_plugin/plugins/fields.py @@ -1,7 +1,12 @@ +from typing import cast + +from mypy.checker import TypeChecker +from mypy.nodes import ListExpr, NameExpr, TupleExpr from mypy.plugin import FunctionContext -from mypy.types import Type, Instance +from mypy.types import Instance, TupleType, Type from mypy_django_plugin import helpers +from mypy_django_plugin.plugins.models import iter_over_assignments def determine_type_of_array_field(ctx: FunctionContext) -> Type: @@ -16,3 +21,54 @@ def determine_type_of_array_field(ctx: FunctionContext) -> Type: return ctx.api.named_generic_type(ctx.context.callee.fullname, args=[get_method.type.ret_type]) + + +def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type: + api = cast(TypeChecker, ctx.api) + outer_model = api.scope.active_class() + if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME): + # outside models.Model class, undetermined + return ctx.default_return_type + + field_name = None + for name_expr, stmt in iter_over_assignments(outer_model.defn): + if stmt == ctx.context and isinstance(name_expr, NameExpr): + field_name = name_expr.name + break + if field_name is None: + return ctx.default_return_type + + fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {}) + + # primary key + is_primary_key = False + primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key') + if primary_key_arg: + is_primary_key = helpers.parse_bool(primary_key_arg) + fields_metadata[field_name] = {'primary_key': is_primary_key} + + # choices + choices_arg = helpers.get_argument_by_name(ctx, 'choices') + if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)): + # iterable of 2 element tuples of two kinds + _, analyzed_choices = api.analyze_iterable_item_type(choices_arg) + if isinstance(analyzed_choices, TupleType): + first_element_type = analyzed_choices.items[0] + if isinstance(first_element_type, Instance): + fields_metadata[field_name]['choices'] = first_element_type.type.fullname() + + # nullability + null_arg = helpers.get_argument_by_name(ctx, 'null') + is_nullable = False + if null_arg: + is_nullable = helpers.parse_bool(null_arg) + fields_metadata[field_name]['null'] = is_nullable + + # is_blankable + blank_arg = helpers.get_argument_by_name(ctx, 'blank') + is_blankable = False + if blank_arg: + is_blankable = helpers.parse_bool(blank_arg) + fields_metadata[field_name]['blank'] = is_blankable + + return ctx.default_return_type diff --git a/mypy_django_plugin/plugins/init_create.py b/mypy_django_plugin/plugins/init_create.py new file mode 100644 index 0000000..cef776a --- /dev/null +++ b/mypy_django_plugin/plugins/init_create.py @@ -0,0 +1,182 @@ +from typing import Dict, Optional, Set, cast, Any + +from mypy.checker import TypeChecker +from mypy.nodes import FuncDef, TypeInfo, Var +from mypy.plugin import FunctionContext, MethodContext +from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType + +from mypy_django_plugin import helpers + + +def extract_base_pointer_args(model: TypeInfo) -> Set[str]: + pointer_args: Set[str] = set() + for base in model.bases: + if base.type.has_base(helpers.MODEL_CLASS_FULLNAME): + parent_name = base.type.name().lower() + pointer_args.add(f'{parent_name}_ptr') + pointer_args.add(f'{parent_name}_ptr_id') + return pointer_args + + +def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type: + assert isinstance(ctx.default_return_type, Instance) + + api = cast(TypeChecker, ctx.api) + model: TypeInfo = ctx.default_return_type.type + + expected_types = extract_expected_types(ctx, model) + # order is preserved, can use for positionals + positional_names = list(expected_types.keys()) + positional_names.remove('pk') + visited_positionals = set() + + # check positionals + for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])): + actual_pos_name = positional_names[i] + api.check_subtype(actual_pos_type, expected_types[actual_pos_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_pos_name, + model.name()), + 'got', 'expected') + visited_positionals.add(actual_pos_name) + + # extract name of base models for _ptr + base_pointer_args = extract_base_pointer_args(model) + + # check kwargs + for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])): + if actual_name in base_pointer_args: + # parent_ptr args are not supported + continue + if actual_name in visited_positionals: + continue + if actual_name is None: + # unpacked dict as kwargs is not supported + continue + if actual_name not in expected_types: + ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, + model.name()), + ctx.context) + continue + api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, + model.name()), + 'got', 'expected') + return ctx.default_return_type + + +def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type: + api = cast(TypeChecker, ctx.api) + if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0: + model: TypeInfo = ctx.type.args[0].type + else: + if isinstance(ctx.default_return_type, AnyType): + return ctx.default_return_type + model: TypeInfo = ctx.default_return_type.type + + # extract name of base models for _ptr + base_pointer_args = extract_base_pointer_args(model) + expected_types = extract_expected_types(ctx, model) + + for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): + if actual_name in base_pointer_args: + # parent_ptr args are not supported + continue + if actual_name is None: + # unpacked dict as kwargs is not supported + continue + if actual_name not in expected_types: + api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, + model.name()), + ctx.context) + continue + api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, + model.name()), + 'got', 'expected') + + return ctx.default_return_type + + +def extract_field_setter_type(tp: Instance) -> Optional[Type]: + if not isinstance(tp, Instance): + return None + if tp.type.has_base(helpers.FIELD_FULLNAME): + set_method = tp.type.get_method('__set__') + if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): + if 'value' in set_method.type.arg_names: + set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] + if isinstance(set_value_type, Instance): + set_value_type = helpers.fill_typevars(tp, set_value_type) + return set_value_type + elif isinstance(set_value_type, UnionType): + items_no_typevars = [] + for item in set_value_type.items: + if isinstance(item, Instance): + item = helpers.fill_typevars(tp, item) + items_no_typevars.append(item) + return UnionType(items_no_typevars) + + get_method = tp.type.get_method('__get__') + if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): + return get_method.type.ret_type + # GenericForeignKey + if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME): + return AnyType(TypeOfAny.special_form) + return None + + +def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]: + return model.metadata.setdefault('django', {}).setdefault('fields', {}) + + +def extract_primary_key_type(model: TypeInfo) -> Optional[Type]: + for field_name, props in get_fields_metadata(model).items(): + is_primary_key = props.get('primary_key', False) + if is_primary_key: + return extract_field_setter_type(model.names[field_name].type) + return None + + +def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: + field_metadata = get_fields_metadata(model).get(field_name, {}) + if 'choices' in field_metadata: + return field_metadata['choices'] + return None + + +def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: + expected_types: Dict[str, Type] = {} + + primary_key_type = extract_primary_key_type(model) + if not primary_key_type: + # no explicit primary key, set pk to Any and add id + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) + + expected_types['pk'] = primary_key_type + + for base in model.mro: + for name, sym in base.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + field_type = extract_field_setter_type(tp) + if field_type is None: + continue + + choices_type_fullname = extract_choices_type(model, name) + if choices_type_fullname: + field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])]) + + if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: + ref_to_model = tp.args[0] + if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): + primary_key_type = extract_primary_key_type(ref_to_model.type) + if not primary_key_type: + primary_key_type = AnyType(TypeOfAny.special_form) + expected_types[name + '_id'] = primary_key_type + if field_type: + expected_types[name] = field_type + return expected_types diff --git a/mypy_django_plugin/plugins/migrations.py b/mypy_django_plugin/plugins/migrations.py index 65500d4..b6baad8 100644 --- a/mypy_django_plugin/plugins/migrations.py +++ b/mypy_django_plugin/plugins/migrations.py @@ -1,9 +1,9 @@ -from typing import cast, Optional +from typing import Optional, cast from mypy.checker import TypeChecker -from mypy.nodes import TypeInfo, Expression, StrExpr, NameExpr, RefExpr, Var +from mypy.nodes import Expression, StrExpr, TypeInfo from mypy.plugin import MethodContext -from mypy.types import Type, Instance, TypeType +from mypy.types import Instance, Type, TypeType from mypy_django_plugin import helpers diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 7f65960..62cb1d3 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -1,13 +1,13 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, Iterator, Optional, Tuple, cast +from typing import Dict, Iterator, List, Optional, Tuple, cast import dataclasses -from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \ - MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2, ARG_STAR +from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \ + Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp +from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny from mypy_django_plugin import helpers @@ -27,18 +27,20 @@ class ModelClassInitializer(metaclass=ABCMeta): return metaclass_sym.node return None - def is_abstract_model(self) -> bool: + def get_meta_attribute(self, name: str) -> Optional[Expression]: meta_node = self.get_nested_meta_node() if meta_node is None: - return False + return None for lvalue, rvalue in iter_over_assignments(meta_node.defn): - if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract': - is_abstract = self.api.parse_bool(rvalue) - if is_abstract: - # abstract model do not need 'objects' queryset - return True - return False + if isinstance(lvalue, NameExpr) and lvalue.name == name: + return rvalue + + def is_abstract_model(self) -> bool: + is_abstract_expr = self.get_meta_attribute('abstract') + if is_abstract_expr is None: + return False + return self.api.parse_bool(is_abstract_expr) def add_new_node_to_model_class(self, name: str, typ: Instance) -> None: var = Var(name=name, type=typ) @@ -93,25 +95,65 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): meta_node.fallback_to_any = True +def get_model_argument(manager_info: TypeInfo) -> Optional[Instance]: + for base in manager_info.bases: + if base.args: + model_arg = base.args[0] + if isinstance(model_arg, Instance) and model_arg.type.has_base(helpers.MODEL_CLASS_FULLNAME): + return model_arg + return None + + class AddDefaultObjectsManager(ModelClassInitializer): - def is_default_objects_attr(self, sym: SymbolTableNode) -> bool: - return sym.fullname == helpers.MODEL_CLASS_FULLNAME + '.' + 'objects' + def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None: + if manager_type is None: + return None + self.add_new_node_to_model_class(name, manager_type) + + def add_private_default_manager(self, manager_type: Optional[Instance]) -> None: + if manager_type is None: + return None + self.add_new_node_to_model_class('_default_manager', manager_type) + + def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]: + managers = [] + for base in self.model_classdef.info.mro: + for name_expr, member_expr in iter_call_assignments(base.defn): + manager_name = name_expr.name + callee_expr = member_expr.callee + if isinstance(callee_expr, IndexExpr): + callee_expr = callee_expr.analyzed.expr + if isinstance(callee_expr, (MemberExpr, NameExpr)) \ + and isinstance(callee_expr.node, TypeInfo) \ + and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME): + managers.append((manager_name, callee_expr.node)) + return managers def run(self) -> None: - existing_objects_sym = self.model_classdef.info.get('objects') - if (existing_objects_sym is not None - and not self.is_default_objects_attr(existing_objects_sym)): - return None + existing_managers = self.get_existing_managers() + if existing_managers: + first_manager_type = None + for manager_name, manager_type_info in existing_managers: + manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])]) + self.add_new_manager(name=manager_name, manager_type=manager_type) + if first_manager_type is None: + first_manager_type = manager_type + else: + if self.is_abstract_model(): + # abstract models do not need 'objects' queryset + return None + + first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME, + args=[Instance(self.model_classdef.info, [])]) + self.add_new_manager('objects', manager_type=first_manager_type) if self.is_abstract_model(): - # abstract models do not need 'objects' queryset return None - - typ = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME, - args=[Instance(self.model_classdef.info, [])]) - if not typ: - return None - self.add_new_node_to_model_class('objects', typ) + default_manager_name_expr = self.get_meta_attribute('default_manager_name') + if isinstance(default_manager_name_expr, StrExpr): + self.add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type) + else: + self.add_private_default_manager(first_manager_type) class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer): diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index 9662de0..d5ca59c 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -75,7 +75,9 @@ IGNORED_ERRORS = { 'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"', 'Incompatible types in assignment', '"object" not callable', - 'Incompatible type for "pk" of "Collector" (got "int", expected "str")' + 'Incompatible type for "pk" of "Collector" (got "int", expected "str")', + re.compile('Unexpected attribute "[a-z]+" for model "Model"'), + 'Unexpected attribute "two_id" for model "CyclicOne"' ], 'aggregation': [ 'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")', @@ -207,7 +209,8 @@ IGNORED_ERRORS = { 'Incompatible types in assignment (expression has type "Type[Field]', 'DummyArrayField', 'DummyJSONField', - 'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"' + 'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"', + 'for model "CITestModel"' ], 'properties': [ re.compile('Unexpected attribute "(full_name|full_name_2)" for model "Person"') @@ -430,7 +433,6 @@ TESTS_DIRS = [ 'model_options', 'model_package', 'model_regress', - # not practical 'modeladmin', # TODO: 'multiple_database', 'mutually_referential', @@ -495,9 +497,7 @@ TESTS_DIRS = [ 'transaction_hooks', 'transactions', 'unmanaged_models', - 'update', - 'update_only_fields', 'urlpatterns', diff --git a/test-data/typecheck/managers.test b/test-data/typecheck/managers.test new file mode 100644 index 0000000..88ad066 --- /dev/null +++ b/test-data/typecheck/managers.test @@ -0,0 +1,139 @@ +[CASE test_every_model_has_objects_queryset_available] +from django.db import models +class User(models.Model): + pass +reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]' +reveal_type(User.objects.get()) # E: Revealed type is 'main.User*' + +[CASE every_model_has_its_own_objects_queryset] +from django.db import models +class Parent(models.Model): + pass +class Child(Parent): + pass +reveal_type(Parent.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Parent]' +reveal_type(Child.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Child]' +[out] + +[CASE if_manager_is_defined_on_model_do_not_add_objects] +from django.db import models + +class MyModel(models.Model): + authors = models.Manager[MyModel]() +reveal_type(MyModel.authors) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]' +reveal_type(MyModel.objects) # E: Revealed type is 'Any' +[out] + +[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter] +from typing import TypeVar, Generic, Type +from django.db import models + +_T = TypeVar('_T', bound=models.Model) +class Base(Generic[_T]): + def __init__(self, model_cls: Type[_T]): + self.model_cls = model_cls + reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]' +class MyModel(models.Model): + pass +base_instance = Base(MyModel) +reveal_type(base_instance.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]' + +class Child(Base[MyModel]): + def method(self) -> None: + reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]' + +[CASE if_custom_manager_defined_it_is_set_to_default_manager] +from typing import TypeVar +from django.db import models +_T = TypeVar('_T', bound=models.Model) +class CustomManager(models.Manager[_T]): + pass +class MyModel(models.Model): + manager = CustomManager[MyModel]() +reveal_type(MyModel._default_manager) # E: Revealed type is 'main.CustomManager[main.MyModel]' + +[CASE if_default_manager_name_is_passed_set_default_manager_to_it] +from typing import TypeVar +from django.db import models + +_T = TypeVar('_T', bound=models.Model) +class Manager1(models.Manager[_T]): + pass +class Manager2(models.Manager[_T]): + pass +class MyModel(models.Model): + class Meta: + default_manager_name = 'm2' + m1: Manager1[MyModel] + m2: Manager2[MyModel] +reveal_type(MyModel._default_manager) # E: Revealed type is 'main.Manager2[main.MyModel]' + +[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class] +from django.db import models + +class UserManager(models.Manager[MyUser]): + def get_or_404(self) -> MyUser: + pass + +class MyUser(models.Model): + objects = UserManager() + +reveal_type(MyUser.objects) # E: Revealed type is 'main.UserManager[main.MyUser]' +reveal_type(MyUser.objects.get()) # E: Revealed type is 'main.MyUser*' +reveal_type(MyUser.objects.get_or_404()) # E: Revealed type is 'main.MyUser' + +[CASE model_imported_from_different_file] +from django.db import models +from models.main import Inventory + +class Band(models.Model): + pass +reveal_type(Inventory.objects) # E: Revealed type is 'django.db.models.manager.Manager[models.main.Inventory]' +reveal_type(Band.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Band]' +[file models/__init__.py] +[file models/main.py] +from django.db import models +class Inventory(models.Model): + pass + +[CASE managers_that_defined_on_other_models_do_not_influence] +from django.db import models + +class AbstractPerson(models.Model): + abstract_persons = models.Manager[AbstractPerson]() +class PublishedBookManager(models.Manager[Book]): + pass +class AnnotatedBookManager(models.Manager[Book]): + pass +class Book(models.Model): + title = models.CharField(max_length=50) + published_objects = PublishedBookManager() + annotated_objects = AnnotatedBookManager() + +reveal_type(AbstractPerson.abstract_persons) # E: Revealed type is 'django.db.models.manager.Manager[main.AbstractPerson]' +reveal_type(Book.published_objects) # E: Revealed type is 'main.PublishedBookManager[main.Book]' +Book.published_objects.create(title='hello') +reveal_type(Book.annotated_objects) # E: Revealed type is 'main.AnnotatedBookManager[main.Book]' +Book.annotated_objects.create(title='hello') +[out] + +[CASE managers_inherited_from_abstract_classes_multiple_inheritance] +from django.db import models +class CustomManager1(models.Manager[AbstractBase1]): + pass +class AbstractBase1(models.Model): + class Meta: + abstract = True + name = models.CharField(max_length=50) + manager1 = CustomManager1() +class CustomManager2(models.Manager[AbstractBase2]): + pass +class AbstractBase2(models.Model): + class Meta: + abstract = True + value = models.CharField(max_length=50) + restricted = CustomManager2() + +class Child(AbstractBase1, AbstractBase2): + pass +[out] \ No newline at end of file diff --git a/test-data/typecheck/model_create.test b/test-data/typecheck/model_create.test new file mode 100644 index 0000000..3a15fa2 --- /dev/null +++ b/test-data/typecheck/model_create.test @@ -0,0 +1,34 @@ +[CASE default_manager_create_is_typechecked] +from django.db import models + +class User(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() + +User.objects.create(name='Max', age=10) +User.objects.create(name=1010) # E: Incompatible type for "name" of "User" (got "int", expected "Union[str, Combinable]") +[out] + +[CASE model_recognises_parent_attributes] +from django.db import models + +class Parent(models.Model): + name = models.CharField(max_length=100) +class Child(Parent): + lastname = models.CharField(max_length=100) +Child.objects.create(name='Maxim', lastname='Maxim2') +[out] + +[CASE deep_multiple_inheritance_with_create] +from django.db import models + +class Parent1(models.Model): + name1 = models.CharField(max_length=50) +class Parent2(models.Model): + id2 = models.AutoField(primary_key=True) + name2 = models.CharField(max_length=50) +class Child1(Parent1, Parent2): + value = models.IntegerField() +class Child4(Child1): + value4 = models.IntegerField() +Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) \ No newline at end of file diff --git a/test-data/typecheck/model_init.test b/test-data/typecheck/model_init.test index bd6986a..29b2527 100644 --- a/test-data/typecheck/model_init.test +++ b/test-data/typecheck/model_init.test @@ -149,3 +149,9 @@ MyModel(field=time()) MyModel(field='12:00') MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]") +[CASE charfield_with_integer_choices] +from django.db import models +class MyModel(models.Model): + day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat'))) +MyModel(day=1) +[out] diff --git a/test-data/typecheck/models.test b/test-data/typecheck/models.test deleted file mode 100644 index c29c497..0000000 --- a/test-data/typecheck/models.test +++ /dev/null @@ -1,49 +0,0 @@ -[CASE test_every_model_has_objects_queryset_available] -from django.db import models -class User(models.Model): - pass -reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]' -reveal_type(User.objects.get()) # E: Revealed type is 'main.User*' - -[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class] -from django.db import models - -class UserManager(models.Manager[User]): - def get_or_404(self) -> User: - pass - -class User(models.Model): - objects = UserManager() - -reveal_type(User.objects) # E: Revealed type is 'main.UserManager' -reveal_type(User.objects.get()) # E: Revealed type is 'main.User*' -reveal_type(User.objects.get_or_404()) # E: Revealed type is 'main.User' - -[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_parameter] -from typing import Type -from django.db import models - -class Base: - def __init__(self, model_cls: Type[models.Model]): - self.model_cls = model_cls -class MyModel(models.Model): - pass -reveal_type(Base(MyModel).model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]' - -[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter] -from typing import TypeVar, Generic, Type -from django.db import models - -_T = TypeVar('_T', bound=models.Model) -class Base(Generic[_T]): - def __init__(self, model_cls: Type[_T]): - self.model_cls = model_cls - reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]' -class MyModel(models.Model): - pass -base_instance = Base(MyModel) -reveal_type(base_instance.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]' - -class Child(Base[MyModel]): - def method(self) -> None: - reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]' \ No newline at end of file