from abc import abstractmethod from typing import Type, Optional from django.db.models.base import Model from django.db.models.fields.related import OneToOneField, ForeignKey from django.db.models.fields.reverse_related import OneToOneRel, ManyToManyRel, ManyToOneRel from mypy.checker import gen_unique_name from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF from mypy.plugin import ClassDefContext from mypy.semanal import dummy_context from mypy.types import Instance, TypeOfAny, AnyType from mypy.types import Type as MypyType from django.db import models from mypy_django_plugin.lib import helpers, fullnames from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers.fields import get_field_type from mypy_django_plugin.transformers2 import new_helpers class TransformModelClassCallback(helpers.ClassDefPluginCallback): def get_real_manager_fullname(self, manager_fullname: str) -> str: model_info = self.lookup_typeinfo_or_defer(fullnames.MODEL_CLASS_FULLNAME) real_manager_fullname = model_info.metadata.get('managers', {}).get(manager_fullname, manager_fullname) return real_manager_fullname def modify_class_defn(self) -> None: model_cls = self.django_context.get_model_class_by_fullname(self.class_defn.fullname) if model_cls is None: return None return self.modify_model_class_defn(model_cls) def add_new_model_attribute(self, name: str, typ: MypyType, force_replace: bool = False) -> None: model_info = self.class_defn.info if name in model_info.names and not force_replace: raise ValueError('Attribute already exists on the model') var = Var(name, type=typ) var.info = model_info var._fullname = self.semanal_api.qualified_name(name) var.is_initialized_in_class = True sym = SymbolTableNode(MDEF, var, plugin_generated=True) error_context = None if force_replace else dummy_context() added = self.semanal_api.add_symbol_table_node(name, sym, context=error_context) assert added def lookup_typeinfo_for_class_or_defer(self, klass: type, *, reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]: manager_cls_fullname = helpers.get_class_fullname(klass) return self.lookup_typeinfo_or_defer(manager_cls_fullname, reason_for_defer=reason_for_defer) @abstractmethod def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: raise NotImplementedError class AddDefaultManagerCallback(TransformModelClassCallback): def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: if ('_default_manager' in self.class_defn.info.names or runtime_model_cls._meta.default_manager is None): return None runtime_default_manager_class = runtime_model_cls._meta.default_manager.__class__ runtime_manager_cls_fullname = new_helpers.get_class_fullname(runtime_default_manager_class) manager_cls_fullname = self.get_real_manager_fullname(runtime_manager_cls_fullname) default_manager_info = self.lookup_typeinfo_or_defer(manager_cls_fullname) if default_manager_info is None: return self.add_new_model_attribute('_default_manager', Instance(default_manager_info, [Instance(self.class_defn.info, [])])) class AddManagersCallback(TransformModelClassCallback): def modify_model_class_defn(self, runtime_model_cls: Type[models.Model]) -> None: for manager_name, manager in runtime_model_cls._meta.managers_map.items(): if manager_name in self.class_defn.info.names: # already defined on the current model class, in file or at a previous iteration continue manager_info = self.lookup_typeinfo_for_class_or_defer(manager.__class__) if manager_info is None: continue manager_type = Instance(manager_info, [Instance(self.class_defn.info, [])]) self.add_new_model_attribute(manager_name, manager_type) class AddPrimaryKeyIfDoesNotExist(TransformModelClassCallback): """ Adds default primary key to models which does not define their own. class User(models.Model): name = models.TextField() """ def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: auto_pk_field = runtime_model_cls._meta.auto_field if auto_pk_field is None: # defined explicitly return None auto_pk_field_name = auto_pk_field.attname if auto_pk_field_name in self.class_defn.info.names: # added on previous iteration return None auto_pk_field_info = self.lookup_typeinfo_for_class_or_defer(auto_pk_field.__class__) if auto_pk_field_info is None: return None self.add_new_model_attribute(auto_pk_field_name, fields.get_field_type(auto_pk_field_info, is_nullable=False)) class AddRelatedManagersCallback(TransformModelClassCallback): def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: for relation in self.django_context.get_model_relations(runtime_model_cls): reverse_manager_name = relation.get_accessor_name() if (reverse_manager_name is None or reverse_manager_name in self.class_defn.info.names): continue related_model_cls = self.django_context.get_field_related_model_cls(relation) if related_model_cls is None: # could not find a referenced model (maybe invalid to= value, or GenericForeignKey) continue related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls) if related_model_info is None: continue if isinstance(relation, OneToOneRel): self.add_new_model_attribute(reverse_manager_name, Instance(related_model_info, [])) elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS) if related_manager_info is None: if not self.defer_till_next_iteration(self.class_defn, reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'): raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS) continue # get type of default_manager for model default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} ' f'of model {helpers.get_class_fullname(related_model_cls)!r}') default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname, reason_for_defer=reason_for_defer) if default_manager_info is None: continue default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) # related_model_cls._meta.default_manager.__class__ # # we're making a subclass of 'objects', need to have it defined # if 'objects' not in related_model_info.names: # if not self.defer_till_next_iteration(self.class_defn, # reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"): # raise AttributeNotFound(self.class_defn.info, 'objects') # continue related_manager_type = Instance(related_manager_info, [Instance(related_model_info, [])]) # # objects_sym = related_model_info.names['objects'] # default_manager_type = objects_sym.type # if default_manager_type is None: # # dynamic base class, extract from django_context # default_manager_cls = related_model_cls._meta.default_manager.__class__ # default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls) # if default_manager_info is None: # continue # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) if (not isinstance(default_manager_type, Instance) or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): # if not defined or trivial -> just return RelatedManager[Model] self.add_new_model_attribute(reverse_manager_name, related_manager_type) continue # make anonymous class name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', self.semanal_api.current_symbol_table()) bases = [related_manager_type, default_manager_type] new_manager_info = self.new_typeinfo(name, bases) self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, [])) class AddForeignPrimaryKeys(TransformModelClassCallback): def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: for field in runtime_model_cls._meta.get_fields(): if not isinstance(field, (OneToOneField, ForeignKey)): continue rel_pk_field_name = field.attname if rel_pk_field_name in self.class_defn.info.names: continue related_model_cls = self.django_context.get_field_related_model_cls(field) if related_model_cls is None: field_sym = self.class_defn.info.get(field.name) if field_sym is not None and field_sym.node is not None: error_context = field_sym.node else: error_context = self.class_defn self.semanal_api.fail(f'Cannot find model {field.related_model!r} ' f'referenced in field {field.name!r} ', ctx=error_context) self.add_new_model_attribute(rel_pk_field_name, AnyType(TypeOfAny.from_error)) continue if related_model_cls._meta.abstract: continue rel_pk_field = self.django_context.get_primary_key_field(related_model_cls) rel_pk_field_info = self.lookup_typeinfo_for_class_or_defer(rel_pk_field.__class__) if rel_pk_field_info is None: continue field_type = get_field_type(rel_pk_field_info, is_nullable=self.django_context.get_field_nullability(field)) self.add_new_model_attribute(rel_pk_field_name, field_type) class ModelCallback(helpers.ClassDefPluginCallback): def __call__(self, ctx: ClassDefContext) -> None: callback_classes = [ AddManagersCallback, AddPrimaryKeyIfDoesNotExist, AddForeignPrimaryKeys, AddDefaultManagerCallback, AddRelatedManagersCallback, ] for callback_cls in callback_classes: callback = callback_cls(self.plugin) callback.__call__(ctx)