diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 3170696..f56f179 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -6,6 +6,7 @@ from typing import ( ) from django.core.exceptions import FieldError +from django.db import models from django.db.models.base import Model from django.db.models.fields import AutoField, CharField, Field from django.db.models.fields.related import ForeignKey, RelatedField @@ -285,3 +286,15 @@ class DjangoContext: expected_types[field_name] = gfk_set_type return expected_types + + @cached_property + def model_base_classes(self) -> Set[str]: + model_classes = self.apps_registry.get_models() + + all_model_bases = set() + for model_cls in model_classes: + for base_cls in model_cls.mro(): + if issubclass(base_cls, models.Model): + all_model_bases.add(helpers.get_class_fullname(base_cls)) + + return all_model_bases diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index f1b6ad9..77685c1 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -255,8 +255,6 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont return cast(TypeChecker, ctx.api) -def get_all_model_mixins(api: TypeChecker) -> Set[str]: - basemodel_info = lookup_fully_qualified_typeinfo(api, fullnames.MODEL_CLASS_FULLNAME) - if basemodel_info is None: - return set() - return set(get_django_metadata(basemodel_info).get('model_mixins', dict).keys()) +def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: + return (info.fullname() in django_context.model_base_classes + or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 3e485da..8406248 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -174,7 +174,7 @@ class NewSemanalDjangoPlugin(Plugin): if info.has_base(fullnames.FIELD_FULLNAME): return partial(fields.transform_into_proper_return_type, django_context=self.django_context) - if info.has_base(fullnames.MODEL_CLASS_FULLNAME): + if helpers.is_model_subclass_info(info, self.django_context): return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return None @@ -213,7 +213,8 @@ class NewSemanalDjangoPlugin(Plugin): def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in self._get_current_model_bases(): + if (fullname in self.django_context.model_base_classes + or fullname in self._get_current_model_bases()): return partial(transform_model_class, django_context=self.django_context) if fullname in self._get_current_manager_bases(): diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index f91548f..19d04ff 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -15,7 +15,7 @@ from mypy_django_plugin.lib import fullnames, helpers def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() if (outer_model_info is None - or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): + or not helpers.is_model_subclass_info(outer_model_info, django_context)): return None field_name = None @@ -117,10 +117,9 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() if (outer_model_info is None - or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME) - and outer_model_info.fullname() not in helpers.get_all_model_mixins(helpers.get_typechecker_api(ctx))): - # not inside models.Model class + or not helpers.is_model_subclass_info(outer_model_info, django_context)): return ctx.default_return_type + assert isinstance(outer_model_info, TypeInfo) if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 9c72e42..9155aa3 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -218,18 +218,6 @@ class AddMetaOptionsAttribute(ModelClassInitializer): ])) -class RecordAllModelMixins(ModelClassInitializer): - def run(self) -> None: - basemodel_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.MODEL_CLASS_FULLNAME) - basemodel_metadata = helpers.get_django_metadata(basemodel_info) - if 'model_mixins' not in basemodel_metadata: - basemodel_metadata['model_mixins'] = {} - - for base_info in self.model_classdef.info.mro[1:]: - if base_info.fullname() != 'builtins.object': - basemodel_metadata['model_mixins'][base_info.fullname()] = 1 - - def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None: initializers = [ @@ -241,7 +229,6 @@ def process_model_class(ctx: ClassDefContext, AddRelatedManagers, AddExtraFieldMethods, AddMetaOptionsAttribute, - RecordAllModelMixins, ] for initializer_cls in initializers: try: diff --git a/test-data/typecheck/fields/test_base.yml b/test-data/typecheck/fields/test_base.yml index 9470eb6..48ce577 100644 --- a/test-data/typecheck/fields/test_base.yml +++ b/test-data/typecheck/fields/test_base.yml @@ -144,7 +144,10 @@ - path: myapp/models.py content: | from django.db import models - class AuthMixin: + class AuthMixin(models.Model): + class Meta: + abstract = True username = models.CharField(max_length=100) + class MyModel(AuthMixin, models.Model): pass \ No newline at end of file