From 26a80a82796a563d60c384fd565d237871ce2e8d Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Wed, 13 Feb 2019 21:05:02 +0300 Subject: [PATCH] add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models --- mypy_django_plugin/helpers.py | 77 +++++++++++++++++++++-- mypy_django_plugin/main.py | 33 +++++++++- mypy_django_plugin/plugins/init_create.py | 59 ++++------------- mypy_django_plugin/plugins/models.py | 9 ++- test-data/typecheck/related_fields.test | 36 +++++++++-- 5 files changed, 153 insertions(+), 61 deletions(-) diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index c451e2c..0c28f32 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -2,10 +2,10 @@ import typing from typing import Dict, Optional from mypy.checker import TypeChecker -from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \ - ClassDef +from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \ + TypeInfo from mypy.plugin import FunctionContext -from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType +from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' FIELD_FULLNAME = 'django.db.models.fields.Field' @@ -172,7 +172,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression return None -def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]: +def iter_over_assignments( + class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]: if isinstance(class_or_module, ClassDef): statements = class_or_module.defs.body else: @@ -185,3 +186,71 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> # not supported yet continue yield stmt.lvalues[0], stmt.rvalue + + +def extract_field_setter_type(tp: Instance) -> Optional[Type]: + if not isinstance(tp, Instance): + return None + if tp.type.has_base(FIELD_FULLNAME): + set_method = tp.type.get_method('__set__') + if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): + if 'value' in set_method.type.arg_names: + set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] + if isinstance(set_value_type, Instance): + set_value_type = fill_typevars(tp, set_value_type) + return set_value_type + elif isinstance(set_value_type, UnionType): + items_no_typevars = [] + for item in set_value_type.items: + if isinstance(item, Instance): + item = fill_typevars(tp, item) + items_no_typevars.append(item) + return UnionType(items_no_typevars) + + field_getter_type = extract_field_getter_type(tp) + if field_getter_type: + return field_getter_type + + return None + + +def extract_field_getter_type(tp: Instance) -> Optional[Type]: + if not isinstance(tp, Instance): + return None + if tp.type.has_base(FIELD_FULLNAME): + get_method = tp.type.get_method('__get__') + if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): + return get_method.type.ret_type + # GenericForeignKey + if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): + return AnyType(TypeOfAny.special_form) + return None + + +def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]: + return model.metadata.setdefault('django', {}) + + +def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]: + django_metadata = get_django_metadata(base_model) + return django_metadata.setdefault('related_field_primary_keys', []) + + +def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]: + return get_django_metadata(model).setdefault('fields', {}) + + +def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]: + for field_name, props in get_fields_metadata(model).items(): + is_primary_key = props.get('primary_key', False) + if is_primary_key: + return extract_field_setter_type(model.names[field_name].type) + return None + + +def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]: + for field_name, props in get_fields_metadata(model).items(): + is_primary_key = props.get('primary_key', False) + if is_primary_key: + return extract_field_getter_type(model.names[field_name].type) + return None diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 3697546..a21fa6d 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -4,8 +4,8 @@ from typing import Callable, Dict, Optional, cast from mypy.checker import TypeChecker from mypy.nodes import TypeInfo from mypy.options import Options -from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin -from mypy.types import Instance, Type, TypeType +from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext +from mypy.types import Instance, Type, TypeType, AnyType, TypeOfAny from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.config import Config @@ -85,6 +85,27 @@ def return_user_model_hook(ctx: FunctionContext) -> Type: return TypeType(Instance(model_info, [])) +def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type: + if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): + return ctx.default_attr_type + + if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + return ctx.default_attr_type + + field_name = ctx.context.name.split('_')[0] + sym = ctx.type.type.get(field_name) + if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0: + to_arg = sym.type.args[0] + if isinstance(to_arg, AnyType): + return AnyType(TypeOfAny.special_form) + + model_type: TypeInfo = to_arg.type + primary_key_type = helpers.extract_primary_key_type_for_get(model_type) + if primary_key_type: + return primary_key_type + return ctx.default_attr_type + + class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -186,6 +207,14 @@ class DjangoPlugin(Plugin): return None + def get_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + # sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) + # if sym and isinstance(sym.node, TypeInfo): + # if fullname.rpartition('.')[-1] in helpers.get_related_field_primary_key_names(sym.node): + return extract_and_return_primary_key_of_bound_related_field_parameter + + def plugin(version): return DjangoPlugin diff --git a/mypy_django_plugin/plugins/init_create.py b/mypy_django_plugin/plugins/init_create.py index 3066249..ee9cc6a 100644 --- a/mypy_django_plugin/plugins/init_create.py +++ b/mypy_django_plugin/plugins/init_create.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, Optional, Set, cast +from typing import Dict, Optional, Set, cast from mypy.checker import TypeChecker -from mypy.nodes import FuncDef, TypeInfo, Var +from mypy.nodes import TypeInfo, Var from mypy.plugin import FunctionContext, MethodContext -from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType +from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType from mypy_django_plugin import helpers +from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata def extract_base_pointer_args(model: TypeInfo) -> Set[str]: @@ -103,46 +104,6 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type: return ctx.default_return_type -def extract_field_setter_type(tp: Instance) -> Optional[Type]: - if not isinstance(tp, Instance): - return None - if tp.type.has_base(helpers.FIELD_FULLNAME): - set_method = tp.type.get_method('__set__') - if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType): - if 'value' in set_method.type.arg_names: - set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')] - if isinstance(set_value_type, Instance): - set_value_type = helpers.fill_typevars(tp, set_value_type) - return set_value_type - elif isinstance(set_value_type, UnionType): - items_no_typevars = [] - for item in set_value_type.items: - if isinstance(item, Instance): - item = helpers.fill_typevars(tp, item) - items_no_typevars.append(item) - return UnionType(items_no_typevars) - - get_method = tp.type.get_method('__get__') - if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType): - return get_method.type.ret_type - # GenericForeignKey - if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME): - return AnyType(TypeOfAny.special_form) - return None - - -def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]: - return model.metadata.setdefault('django', {}).setdefault('fields', {}) - - -def extract_primary_key_type(model: TypeInfo) -> Optional[Type]: - for field_name, props in get_fields_metadata(model).items(): - is_primary_key = props.get('primary_key', False) - if is_primary_key: - return extract_field_setter_type(model.names[field_name].type) - return None - - def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: field_metadata = get_fields_metadata(model).get(field_name, {}) if 'choices' in field_metadata: @@ -153,7 +114,7 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: expected_types: Dict[str, Type] = {} - primary_key_type = extract_primary_key_type(model) + primary_key_type = extract_primary_key_type_for_set(model) if not primary_key_type: # no explicit primary key, set pk to Any and add id primary_key_type = AnyType(TypeOfAny.special_form) @@ -178,11 +139,13 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: ref_to_model = tp.args[0] + primary_key_type = AnyType(TypeOfAny.special_form) if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): - primary_key_type = extract_primary_key_type(ref_to_model.type) - if not primary_key_type: - primary_key_type = AnyType(TypeOfAny.special_form) - expected_types[name + '_id'] = primary_key_type + typ = extract_primary_key_type_for_set(ref_to_model.type) + if typ: + primary_key_type = typ + expected_types[name + '_id'] = primary_key_type + if field_type: expected_types[name] = field_type elif isinstance(sym.node.type, AnyType): diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index b849c22..2e354a2 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from typing import Dict, Iterator, List, Optional, Tuple, cast import dataclasses -from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \ +from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \ Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method @@ -74,8 +74,11 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp class SetIdAttrsForRelatedFields(ModelClassInitializer): def run(self) -> None: for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef): - self.add_new_node_to_model_class(lvalue.name + '_id', - typ=self.api.named_type('__builtins__.int')) + # base_model_info = self.api.named_type('builtins.object').type + # helpers.get_related_field_primary_key_names(base_model_info).append(node_name) + node_name = lvalue.name + '_id' + self.add_new_node_to_model_class(name=node_name, + typ=self.api.builtin_type('builtins.int')) class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 6b5dde8..f36d286 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -22,13 +22,11 @@ class Publisher(models.Model): class Book(models.Model): publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) - class StylesheetError(Exception): - pass owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE) book = Book() reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' -reveal_type(book.owner_id) # E: Revealed type is 'builtins.int' +reveal_type(book.owner_id) # E: Revealed type is 'Any' [CASE test_foreign_key_field_different_order_of_params] from django.db import models @@ -68,7 +66,7 @@ from django.db import models class Publisher(models.Model): pass -[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph] +[CASE test_to_parameter_as_string_with_application_name_fallbacks_to_any_if_model_not_present_in_dependency_graph] from django.db import models class Book(models.Model): @@ -76,6 +74,9 @@ class Book(models.Model): book = Book() reveal_type(book.publisher) # E: Revealed type is 'Any' +reveal_type(book.publisher_id) # E: Revealed type is 'Any' +Book(publisher_id=1) +Book.objects.create(publisher_id=1) [file myapp/__init__.py] [file myapp/models.py] @@ -247,3 +248,30 @@ class Publisher(models.Model): class Book(models.Model): publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]' + +[CASE underscore_id_attribute_has_set_type_of_primary_key_if_explicit] +from django.db import models +import datetime +class Publisher(models.Model): + mypk = models.CharField(max_length=100, primary_key=True) +class Book(models.Model): + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + +reveal_type(Book().publisher_id) # E: Revealed type is 'builtins.str' +Book(publisher_id=1) +Book(publisher_id='hello') +Book(publisher_id=datetime.datetime.now()) # E: Incompatible type for "publisher_id" of "Book" (got "datetime", expected "Union[str, int, Combinable]") +Book.objects.create(publisher_id=1) +Book.objects.create(publisher_id='hello') + +class Publisher2(models.Model): + mypk = models.IntegerField(primary_key=True) +class Book2(models.Model): + publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE) + +reveal_type(Book2().publisher_id) # E: Revealed type is 'builtins.int' +Book2(publisher_id=1) +Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") +Book2.objects.create(publisher_id=1) +Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") +[out]