Use runtime information to determine whether class is a models.Model subclass (#182)

This commit is contained in:
Maxim Kurnikov
2019-09-28 04:05:54 +03:00
committed by GitHub
parent 5910bd1b25
commit 2c23d8e70f
6 changed files with 26 additions and 25 deletions

View File

@@ -6,6 +6,7 @@ from typing import (
) )
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import models
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields import AutoField, CharField, Field from django.db.models.fields import AutoField, CharField, Field
from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.related import ForeignKey, RelatedField
@@ -285,3 +286,15 @@ class DjangoContext:
expected_types[field_name] = gfk_set_type expected_types[field_name] = gfk_set_type
return expected_types return expected_types
@cached_property
def model_base_classes(self) -> Set[str]:
model_classes = self.apps_registry.get_models()
all_model_bases = set()
for model_cls in model_classes:
for base_cls in model_cls.mro():
if issubclass(base_cls, models.Model):
all_model_bases.add(helpers.get_class_fullname(base_cls))
return all_model_bases

View File

@@ -255,8 +255,6 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont
return cast(TypeChecker, ctx.api) return cast(TypeChecker, ctx.api)
def get_all_model_mixins(api: TypeChecker) -> Set[str]: def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
basemodel_info = lookup_fully_qualified_typeinfo(api, fullnames.MODEL_CLASS_FULLNAME) return (info.fullname() in django_context.model_base_classes
if basemodel_info is None: or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
return set()
return set(get_django_metadata(basemodel_info).get('model_mixins', dict).keys())

View File

@@ -174,7 +174,7 @@ class NewSemanalDjangoPlugin(Plugin):
if info.has_base(fullnames.FIELD_FULLNAME): if info.has_base(fullnames.FIELD_FULLNAME):
return partial(fields.transform_into_proper_return_type, django_context=self.django_context) return partial(fields.transform_into_proper_return_type, django_context=self.django_context)
if info.has_base(fullnames.MODEL_CLASS_FULLNAME): if helpers.is_model_subclass_info(info, self.django_context):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
return None return None
@@ -213,7 +213,8 @@ class NewSemanalDjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self._get_current_model_bases(): if (fullname in self.django_context.model_base_classes
or fullname in self._get_current_model_bases()):
return partial(transform_model_class, django_context=self.django_context) return partial(transform_model_class, django_context=self.django_context)
if fullname in self._get_current_manager_bases(): if fullname in self._get_current_manager_bases():

View File

@@ -15,7 +15,7 @@ from mypy_django_plugin.lib import fullnames, helpers
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None if (outer_model_info is None
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): or not helpers.is_model_subclass_info(outer_model_info, django_context)):
return None return None
field_name = None field_name = None
@@ -117,10 +117,9 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None if (outer_model_info is None
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME) or not helpers.is_model_subclass_info(outer_model_info, django_context)):
and outer_model_info.fullname() not in helpers.get_all_model_mixins(helpers.get_typechecker_api(ctx))):
# not inside models.Model class
return ctx.default_return_type return ctx.default_return_type
assert isinstance(outer_model_info, TypeInfo) assert isinstance(outer_model_info, TypeInfo)
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):

View File

@@ -218,18 +218,6 @@ class AddMetaOptionsAttribute(ModelClassInitializer):
])) ]))
class RecordAllModelMixins(ModelClassInitializer):
def run(self) -> None:
basemodel_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.MODEL_CLASS_FULLNAME)
basemodel_metadata = helpers.get_django_metadata(basemodel_info)
if 'model_mixins' not in basemodel_metadata:
basemodel_metadata['model_mixins'] = {}
for base_info in self.model_classdef.info.mro[1:]:
if base_info.fullname() != 'builtins.object':
basemodel_metadata['model_mixins'][base_info.fullname()] = 1
def process_model_class(ctx: ClassDefContext, def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None: django_context: DjangoContext) -> None:
initializers = [ initializers = [
@@ -241,7 +229,6 @@ def process_model_class(ctx: ClassDefContext,
AddRelatedManagers, AddRelatedManagers,
AddExtraFieldMethods, AddExtraFieldMethods,
AddMetaOptionsAttribute, AddMetaOptionsAttribute,
RecordAllModelMixins,
] ]
for initializer_cls in initializers: for initializer_cls in initializers:
try: try:

View File

@@ -144,7 +144,10 @@
- path: myapp/models.py - path: myapp/models.py
content: | content: |
from django.db import models from django.db import models
class AuthMixin: class AuthMixin(models.Model):
class Meta:
abstract = True
username = models.CharField(max_length=100) username = models.CharField(max_length=100)
class MyModel(AuthMixin, models.Model): class MyModel(AuthMixin, models.Model):
pass pass