Fix ForeignKey type for self-reference defined in the abstract model (#200)

This commit is contained in:
Maxim Kurnikov
2019-10-05 21:36:29 +03:00
committed by GitHub
parent db9ff6aaf6
commit 7e3f4bfa02
6 changed files with 82 additions and 16 deletions

View File

@@ -209,17 +209,21 @@ class DjangoContext:
return expected_types return expected_types
@cached_property @cached_property
def model_base_classes(self) -> Set[str]: def all_registered_model_classes(self) -> Set[Type[models.Model]]:
model_classes = self.apps_registry.get_models() model_classes = self.apps_registry.get_models()
all_model_bases = set() all_model_bases = set()
for model_cls in model_classes: for model_cls in model_classes:
for base_cls in model_cls.mro(): for base_cls in model_cls.mro():
if issubclass(base_cls, models.Model): if issubclass(base_cls, models.Model):
all_model_bases.add(helpers.get_class_fullname(base_cls)) all_model_bases.add(base_cls)
return all_model_bases return all_model_bases
@cached_property
def all_registered_model_class_fullnames(self) -> Set[str]:
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}
def get_attname(self, field: Field) -> str: def get_attname(self, field: Field) -> str:
attname = field.attname attname = field.attname
return attname return attname

View File

@@ -282,7 +282,7 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
return (info.fullname() in django_context.model_base_classes return (info.fullname() in django_context.all_registered_model_class_fullnames
or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
@@ -292,3 +292,15 @@ def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
api.check_subtype(actual_type, expected_type, api.check_subtype(actual_type, expected_type,
ctx.context, error_message, ctx.context, error_message,
'got', 'expected') 'got', 'expected')
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
# type=: type of the variable itself
var = Var(name=name, type=sym_type)
# var.info: type of the object variable is bound to
var.info = info
var._fullname = info.fullname() + '.' + name
var.is_initialized_in_class = True
var.is_inferred = True
info.names[name] = SymbolTableNode(MDEF, var,
plugin_generated=True)

View File

@@ -219,7 +219,7 @@ class NewSemanalDjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
if (fullname in self.django_context.model_base_classes if (fullname in self.django_context.all_registered_model_class_fullnames
or fullname in self._get_current_model_bases()): or fullname in self._get_current_model_bases()):
return partial(transform_model_class, django_context=self.django_context) return partial(transform_model_class, django_context=self.django_context)

View File

@@ -37,6 +37,14 @@ def _get_current_field_from_assignment(ctx: FunctionContext, django_context: Dja
return current_field return current_field
def reparametrize_related_field_type(related_field_type: Instance, set_type, get_type) -> Instance:
args = [
helpers.convert_any_to_type(related_field_type.args[0], set_type),
helpers.convert_any_to_type(related_field_type.args[1], get_type),
]
return helpers.reparametrize_instance(related_field_type, new_args=args)
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
current_field = _get_current_field_from_assignment(ctx, django_context) current_field = _get_current_field_from_assignment(ctx, django_context)
if current_field is None: if current_field is None:
@@ -48,6 +56,25 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
if related_model_cls is None: if related_model_cls is None:
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
default_related_field_type = set_descriptor_types_for_field(ctx)
# self reference with abstract=True on the model where ForeignKey is defined
current_model_cls = current_field.model
if (current_model_cls._meta.abstract
and current_model_cls == related_model_cls):
# for all derived non-abstract classes, set variable with this name to
# __get__/__set__ of ForeignKey of derived model
for model_cls in django_context.all_registered_model_classes:
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
if derived_model_info is not None:
fk_ref_type = Instance(derived_model_info, [])
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
set_type=fk_ref_type, get_type=fk_ref_type)
helpers.add_new_sym_for_info(derived_model_info,
name=current_field.name,
sym_type=derived_fk_type)
related_model = related_model_cls related_model = related_model_cls
related_model_to_set = related_model_cls related_model_to_set = related_model_cls
if related_model_to_set._meta.proxy_for_model is not None: if related_model_to_set._meta.proxy_for_model is not None:
@@ -69,13 +96,10 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
else: else:
related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore
default_related_field_type = set_descriptor_types_for_field(ctx)
# replace Any with referred_to_type # replace Any with referred_to_type
args = [ return reparametrize_related_field_type(default_related_field_type,
helpers.convert_any_to_type(default_related_field_type.args[0], related_model_to_set_type), set_type=related_model_to_set_type,
helpers.convert_any_to_type(default_related_field_type.args[1], related_model_type), get_type=related_model_type)
]
return helpers.reparametrize_instance(default_related_field_type, new_args=args)
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]: def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:

View File

@@ -7,9 +7,7 @@ from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ( from django.db.models.fields.reverse_related import (
ManyToManyRel, ManyToOneRel, OneToOneRel, ManyToManyRel, ManyToOneRel, OneToOneRel,
) )
from mypy.nodes import ( from mypy.nodes import ARG_STAR2, Argument, Context, TypeInfo, Var
ARG_STAR2, MDEF, Argument, Context, SymbolTableNode, TypeInfo, Var,
)
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins import common from mypy.plugins import common
from mypy.types import AnyType, Instance from mypy.types import AnyType, Instance
@@ -51,8 +49,9 @@ class ModelClassInitializer:
return var return var
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None:
var = self.create_new_var(name, typ) helpers.add_new_sym_for_info(self.model_classdef.info,
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) name=name,
sym_type=typ)
def run(self) -> None: def run(self) -> None:
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
@@ -114,6 +113,9 @@ class AddRelatedModelsId(ModelClassInitializer):
AnyType(TypeOfAny.explicit)) AnyType(TypeOfAny.explicit))
continue continue
if related_model_cls._meta.abstract:
continue
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
is_nullable = self.django_context.get_field_nullability(field, None) is_nullable = self.django_context.get_field_nullability(field, None)

View File

@@ -624,3 +624,27 @@
transaction = models.ForeignKey(Transaction, on_delete=models.CASCADE) transaction = models.ForeignKey(Transaction, on_delete=models.CASCADE)
Transaction().test() Transaction().test()
- case: resolve_primary_keys_for_foreign_keys_with_abstract_self_model
main: |
from myapp.models import User
reveal_type(User().parent) # N: Revealed type is 'myapp.models.User*'
reveal_type(User().parent_id) # N: Revealed type is 'builtins.int*'
reveal_type(User().parent2) # N: Revealed type is 'Union[myapp.models.User, None]'
reveal_type(User().parent2_id) # N: Revealed type is 'Union[builtins.int, None]'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class AbstractUser(models.Model):
parent = models.ForeignKey('self', on_delete=models.CASCADE)
parent2 = models.ForeignKey('self', null=True, on_delete=models.CASCADE)
class Meta:
abstract = True
class User(AbstractUser):
pass