add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models

This commit is contained in:
Maxim Kurnikov
2019-02-13 21:05:02 +03:00
parent 82de0a8791
commit 26a80a8279
5 changed files with 153 additions and 61 deletions

View File

@@ -2,10 +2,10 @@ import typing
from typing import Dict, Optional
from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
ClassDef
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -172,7 +172,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression
return None
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
def iter_over_assignments(
class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
@@ -185,3 +186,71 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) ->
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
field_getter_type = extract_field_getter_type(tp)
if field_getter_type:
return field_getter_type
return None
def extract_field_getter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return model.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]:
django_metadata = get_django_metadata(base_model)
return django_metadata.setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None

View File

@@ -4,8 +4,8 @@ from typing import Callable, Dict, Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type, TypeType
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext
from mypy.types import Instance, Type, TypeType, AnyType, TypeOfAny
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
@@ -85,6 +85,27 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, []))
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
return ctx.default_attr_type
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -186,6 +207,14 @@ class DjangoPlugin(Plugin):
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
# if sym and isinstance(sym.node, TypeInfo):
# if fullname.rpartition('.')[-1] in helpers.get_related_field_primary_key_names(sym.node):
return extract_and_return_primary_key_of_bound_related_field_parameter
def plugin(version):
return DjangoPlugin

View File

@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional, Set, cast
from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker
from mypy.nodes import FuncDef, TypeInfo, Var
from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType
from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
@@ -103,46 +104,6 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
return ctx.default_return_type
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(helpers.FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = helpers.fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = helpers.fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return model.metadata.setdefault('django', {}).setdefault('fields', {})
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata:
@@ -153,7 +114,7 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
primary_key_type = extract_primary_key_type(model)
primary_key_type = extract_primary_key_type_for_set(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
@@ -178,11 +139,13 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
primary_key_type = AnyType(TypeOfAny.special_form)
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
typ = extract_primary_key_type_for_set(ref_to_model.type)
if typ:
primary_key_type = typ
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
elif isinstance(sym.node.type, AnyType):

View File

@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
@@ -74,8 +74,11 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
self.add_new_node_to_model_class(lvalue.name + '_id',
typ=self.api.named_type('__builtins__.int'))
# base_model_info = self.api.named_type('builtins.object').type
# helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):

View File

@@ -22,13 +22,11 @@ class Publisher(models.Model):
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
class StylesheetError(Exception):
pass
owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE)
book = Book()
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
reveal_type(book.owner_id) # E: Revealed type is 'builtins.int'
reveal_type(book.owner_id) # E: Revealed type is 'Any'
[CASE test_foreign_key_field_different_order_of_params]
from django.db import models
@@ -68,7 +66,7 @@ from django.db import models
class Publisher(models.Model):
pass
[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph]
[CASE test_to_parameter_as_string_with_application_name_fallbacks_to_any_if_model_not_present_in_dependency_graph]
from django.db import models
class Book(models.Model):
@@ -76,6 +74,9 @@ class Book(models.Model):
book = Book()
reveal_type(book.publisher) # E: Revealed type is 'Any'
reveal_type(book.publisher_id) # E: Revealed type is 'Any'
Book(publisher_id=1)
Book.objects.create(publisher_id=1)
[file myapp/__init__.py]
[file myapp/models.py]
@@ -247,3 +248,30 @@ class Publisher(models.Model):
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
[CASE underscore_id_attribute_has_set_type_of_primary_key_if_explicit]
from django.db import models
import datetime
class Publisher(models.Model):
mypk = models.CharField(max_length=100, primary_key=True)
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Book().publisher_id) # E: Revealed type is 'builtins.str'
Book(publisher_id=1)
Book(publisher_id='hello')
Book(publisher_id=datetime.datetime.now()) # E: Incompatible type for "publisher_id" of "Book" (got "datetime", expected "Union[str, int, Combinable]")
Book.objects.create(publisher_id=1)
Book.objects.create(publisher_id='hello')
class Publisher2(models.Model):
mypk = models.IntegerField(primary_key=True)
class Book2(models.Model):
publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE)
reveal_type(Book2().publisher_id) # E: Revealed type is 'builtins.int'
Book2(publisher_id=1)
Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
Book2.objects.create(publisher_id=1)
Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
[out]