add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models

This commit is contained in:
Maxim Kurnikov
2019-02-13 21:05:02 +03:00
parent 82de0a8791
commit 26a80a8279
5 changed files with 153 additions and 61 deletions

View File

@@ -2,10 +2,10 @@ import typing
from typing import Dict, Optional from typing import Dict, Optional
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \ from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
ClassDef TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field' FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -172,7 +172,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression
return None return None
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]: def iter_over_assignments(
class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef): if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body statements = class_or_module.defs.body
else: else:
@@ -185,3 +186,71 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) ->
# not supported yet # not supported yet
continue continue
yield stmt.lvalues[0], stmt.rvalue yield stmt.lvalues[0], stmt.rvalue
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
field_getter_type = extract_field_getter_type(tp)
if field_getter_type:
return field_getter_type
return None
def extract_field_getter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return model.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]:
django_metadata = get_django_metadata(base_model)
return django_metadata.setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None

View File

@@ -4,8 +4,8 @@ from typing import Callable, Dict, Optional, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo from mypy.nodes import TypeInfo
from mypy.options import Options from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext
from mypy.types import Instance, Type, TypeType from mypy.types import Instance, Type, TypeType, AnyType, TypeOfAny
from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config from mypy_django_plugin.config import Config
@@ -85,6 +85,27 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, [])) return TypeType(Instance(model_info, []))
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
return ctx.default_attr_type
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None: def __init__(self, options: Options) -> None:
super().__init__(options) super().__init__(options)
@@ -186,6 +207,14 @@ class DjangoPlugin(Plugin):
return None return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
# if sym and isinstance(sym.node, TypeInfo):
# if fullname.rpartition('.')[-1] in helpers.get_related_field_primary_key_names(sym.node):
return extract_and_return_primary_key_of_bound_related_field_parameter
def plugin(version): def plugin(version):
return DjangoPlugin return DjangoPlugin

View File

@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional, Set, cast from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import FuncDef, TypeInfo, Var from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
def extract_base_pointer_args(model: TypeInfo) -> Set[str]: def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
@@ -103,46 +104,6 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
return ctx.default_return_type return ctx.default_return_type
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(helpers.FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = helpers.fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = helpers.fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return model.metadata.setdefault('django', {}).setdefault('fields', {})
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {}) field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata: if 'choices' in field_metadata:
@@ -153,7 +114,7 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {} expected_types: Dict[str, Type] = {}
primary_key_type = extract_primary_key_type(model) primary_key_type = extract_primary_key_type_for_set(model)
if not primary_key_type: if not primary_key_type:
# no explicit primary key, set pk to Any and add id # no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form) primary_key_type = AnyType(TypeOfAny.special_form)
@@ -178,11 +139,13 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0] ref_to_model = tp.args[0]
primary_key_type = AnyType(TypeOfAny.special_form)
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type) typ = extract_primary_key_type_for_set(ref_to_model.type)
if not primary_key_type: if typ:
primary_key_type = AnyType(TypeOfAny.special_form) primary_key_type = typ
expected_types[name + '_id'] = primary_key_type expected_types[name + '_id'] = primary_key_type
if field_type: if field_type:
expected_types[name] = field_type expected_types[name] = field_type
elif isinstance(sym.node.type, AnyType): elif isinstance(sym.node.type, AnyType):

View File

@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, cast from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses import dataclasses
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \ from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
@@ -74,8 +74,11 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
class SetIdAttrsForRelatedFields(ModelClassInitializer): class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None: def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef): for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
self.add_new_node_to_model_class(lvalue.name + '_id', # base_model_info = self.api.named_type('builtins.object').type
typ=self.api.named_type('__builtins__.int')) # helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):

View File

@@ -22,13 +22,11 @@ class Publisher(models.Model):
class Book(models.Model): class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
class StylesheetError(Exception):
pass
owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE) owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE)
book = Book() book = Book()
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
reveal_type(book.owner_id) # E: Revealed type is 'builtins.int' reveal_type(book.owner_id) # E: Revealed type is 'Any'
[CASE test_foreign_key_field_different_order_of_params] [CASE test_foreign_key_field_different_order_of_params]
from django.db import models from django.db import models
@@ -68,7 +66,7 @@ from django.db import models
class Publisher(models.Model): class Publisher(models.Model):
pass pass
[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph] [CASE test_to_parameter_as_string_with_application_name_fallbacks_to_any_if_model_not_present_in_dependency_graph]
from django.db import models from django.db import models
class Book(models.Model): class Book(models.Model):
@@ -76,6 +74,9 @@ class Book(models.Model):
book = Book() book = Book()
reveal_type(book.publisher) # E: Revealed type is 'Any' reveal_type(book.publisher) # E: Revealed type is 'Any'
reveal_type(book.publisher_id) # E: Revealed type is 'Any'
Book(publisher_id=1)
Book.objects.create(publisher_id=1)
[file myapp/__init__.py] [file myapp/__init__.py]
[file myapp/models.py] [file myapp/models.py]
@@ -247,3 +248,30 @@ class Publisher(models.Model):
class Book(models.Model): class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]' reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
[CASE underscore_id_attribute_has_set_type_of_primary_key_if_explicit]
from django.db import models
import datetime
class Publisher(models.Model):
mypk = models.CharField(max_length=100, primary_key=True)
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Book().publisher_id) # E: Revealed type is 'builtins.str'
Book(publisher_id=1)
Book(publisher_id='hello')
Book(publisher_id=datetime.datetime.now()) # E: Incompatible type for "publisher_id" of "Book" (got "datetime", expected "Union[str, int, Combinable]")
Book.objects.create(publisher_id=1)
Book.objects.create(publisher_id='hello')
class Publisher2(models.Model):
mypk = models.IntegerField(primary_key=True)
class Book2(models.Model):
publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE)
reveal_type(Book2().publisher_id) # E: Revealed type is 'builtins.int'
Book2(publisher_id=1)
Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
Book2.objects.create(publisher_id=1)
Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
[out]