diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index bca6989..6adc50d 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -209,17 +209,21 @@ class DjangoContext: return expected_types @cached_property - def model_base_classes(self) -> Set[str]: + def all_registered_model_classes(self) -> Set[Type[models.Model]]: model_classes = self.apps_registry.get_models() all_model_bases = set() for model_cls in model_classes: for base_cls in model_cls.mro(): if issubclass(base_cls, models.Model): - all_model_bases.add(helpers.get_class_fullname(base_cls)) + all_model_bases.add(base_cls) return all_model_bases + @cached_property + def all_registered_model_class_fullnames(self) -> Set[str]: + return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes} + def get_attname(self, field: Field) -> str: attname = field.attname return attname diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 40f0a63..9689acb 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -282,7 +282,7 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: - return (info.fullname() in django_context.model_base_classes + return (info.fullname() in django_context.all_registered_model_class_fullnames or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) @@ -292,3 +292,15 @@ def check_types_compatible(ctx: Union[FunctionContext, MethodContext], api.check_subtype(actual_type, expected_type, ctx.context, error_message, 'got', 'expected') + + +def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: + # type=: type of the variable itself + var = Var(name=name, type=sym_type) + # var.info: type of the object variable is bound to + var.info = info + var._fullname = info.fullname() + '.' + name + var.is_initialized_in_class = True + var.is_inferred = True + info.names[name] = SymbolTableNode(MDEF, var, + plugin_generated=True) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 424d8b0..5a39b85 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -219,7 +219,7 @@ class NewSemanalDjangoPlugin(Plugin): def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - if (fullname in self.django_context.model_base_classes + if (fullname in self.django_context.all_registered_model_class_fullnames or fullname in self._get_current_model_bases()): return partial(transform_model_class, django_context=self.django_context) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index a809cd2..32d20bf 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -37,6 +37,14 @@ def _get_current_field_from_assignment(ctx: FunctionContext, django_context: Dja return current_field +def reparametrize_related_field_type(related_field_type: Instance, set_type, get_type) -> Instance: + args = [ + helpers.convert_any_to_type(related_field_type.args[0], set_type), + helpers.convert_any_to_type(related_field_type.args[1], get_type), + ] + return helpers.reparametrize_instance(related_field_type, new_args=args) + + def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: current_field = _get_current_field_from_assignment(ctx, django_context) if current_field is None: @@ -48,6 +56,25 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context if related_model_cls is None: return AnyType(TypeOfAny.from_error) + default_related_field_type = set_descriptor_types_for_field(ctx) + + # self reference with abstract=True on the model where ForeignKey is defined + current_model_cls = current_field.model + if (current_model_cls._meta.abstract + and current_model_cls == related_model_cls): + # for all derived non-abstract classes, set variable with this name to + # __get__/__set__ of ForeignKey of derived model + for model_cls in django_context.all_registered_model_classes: + if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract: + derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) + if derived_model_info is not None: + fk_ref_type = Instance(derived_model_info, []) + derived_fk_type = reparametrize_related_field_type(default_related_field_type, + set_type=fk_ref_type, get_type=fk_ref_type) + helpers.add_new_sym_for_info(derived_model_info, + name=current_field.name, + sym_type=derived_fk_type) + related_model = related_model_cls related_model_to_set = related_model_cls if related_model_to_set._meta.proxy_for_model is not None: @@ -69,13 +96,10 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context else: related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore - default_related_field_type = set_descriptor_types_for_field(ctx) # replace Any with referred_to_type - args = [ - helpers.convert_any_to_type(default_related_field_type.args[0], related_model_to_set_type), - helpers.convert_any_to_type(default_related_field_type.args[1], related_model_type), - ] - return helpers.reparametrize_instance(default_related_field_type, new_args=args) + return reparametrize_related_field_type(default_related_field_type, + set_type=related_model_to_set_type, + get_type=related_model_type) def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]: diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 4e93ac4..195f565 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -7,9 +7,7 @@ from django.db.models.fields.related import ForeignKey from django.db.models.fields.reverse_related import ( ManyToManyRel, ManyToOneRel, OneToOneRel, ) -from mypy.nodes import ( - ARG_STAR2, MDEF, Argument, Context, SymbolTableNode, TypeInfo, Var, -) +from mypy.nodes import ARG_STAR2, Argument, Context, TypeInfo, Var from mypy.plugin import ClassDefContext from mypy.plugins import common from mypy.types import AnyType, Instance @@ -51,8 +49,9 @@ class ModelClassInitializer: return var def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: - var = self.create_new_var(name, typ) - self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) + helpers.add_new_sym_for_info(self.model_classdef.info, + name=name, + sym_type=typ) def run(self) -> None: model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) @@ -114,6 +113,9 @@ class AddRelatedModelsId(ModelClassInitializer): AnyType(TypeOfAny.explicit)) continue + if related_model_cls._meta.abstract: + continue + rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) is_nullable = self.django_context.get_field_nullability(field, None) diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index bda03c2..b1e2670 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -623,4 +623,28 @@ class TransactionLog(models.Model): transaction = models.ForeignKey(Transaction, on_delete=models.CASCADE) - Transaction().test() \ No newline at end of file + Transaction().test() + + +- case: resolve_primary_keys_for_foreign_keys_with_abstract_self_model + main: | + from myapp.models import User + reveal_type(User().parent) # N: Revealed type is 'myapp.models.User*' + reveal_type(User().parent_id) # N: Revealed type is 'builtins.int*' + + reveal_type(User().parent2) # N: Revealed type is 'Union[myapp.models.User, None]' + reveal_type(User().parent2_id) # N: Revealed type is 'Union[builtins.int, None]' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class AbstractUser(models.Model): + parent = models.ForeignKey('self', on_delete=models.CASCADE) + parent2 = models.ForeignKey('self', null=True, on_delete=models.CASCADE) + class Meta: + abstract = True + class User(AbstractUser): + pass