diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 9797769..3c8d46e 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -42,6 +42,9 @@ _ST = TypeVar("_ST") # __get__ return type _GT = TypeVar("_GT") +class CharField(Field[str, str]): + + class Field(RegisterLookupMixin, Generic[_ST, _GT]): _pyi_private_set_type: Any _pyi_private_get_type: Any diff --git a/mfile.py b/mfile.py new file mode 100644 index 0000000..29f6885 --- /dev/null +++ b/mfile.py @@ -0,0 +1,113 @@ +from graphviz import Digraph +from mypy.options import Options + +source = """ +from root.package import MyQuerySet + +MyQuerySet().mymethod() +""" + +from mypy import parse + +parsed = parse.parse(source, 'myfile.py', None, None, Options()) +print(parsed) + +graphattrs = { + "labelloc": "t", + "fontcolor": "blue", + # "bgcolor": "#333333", + "margin": "0", +} + +nodeattrs = { + # "color": "white", + "fontcolor": "#00008b", + # "style": "filled", + # "fillcolor": "#ffffff", + # "fillcolor": "#006699", +} + +edgeattrs = { + # "color": "white", + # "fontcolor": "white", +} + +graph = Digraph('mfile.py', graph_attr=graphattrs, node_attr=nodeattrs, edge_attr=edgeattrs) +graph.node('__builtins__') + +graph.node('django.db.models') +graph.node('django.db.models.fields') + +graph.edge('django.db.models', 'django.db.models.fields') +graph.edge('django.db.models', '__builtins__') +graph.edge('django.db.models.fields', '__builtins__') + +graph.node('mymodule') +graph.edge('mymodule', 'django.db.models') +graph.edge('mymodule', '__builtins__') +# +# graph.node('ImportFrom', label='ImportFrom(val=root.package, [MyQuerySet])') +# graph.edge('MypyFile', 'ImportFrom') + + + +# graph.node('ClassDef_MyQuerySet', label='ClassDef(name=MyQuerySet)') +# graph.edge('MypyFile', 'ClassDef_MyQuerySet') +# +# graph.node('FuncDef_mymethod', label='FuncDef(name=mymethod)') +# graph.edge('ClassDef_MyQuerySet', 'FuncDef_mymethod') +# +# graph.node('Args', label='Args') +# graph.edge('FuncDef_mymethod', 'Args') +# +# graph.node('Var_self', label='Var(name=self)') +# graph.edge('Args', 'Var_self') +# +# graph.node('Block', label='Block') +# graph.edge('FuncDef_mymethod', 'Block') +# +# graph.node('PassStmt') +# graph.edge('Block', 'PassStmt') + +# graph.node('ExpressionStmt') +# graph.edge('MypyFile', 'ExpressionStmt') +# +# graph.node('CallExpr', label='CallExpr(val="MyQuerySet()")') +# graph.edge('ExpressionStmt', 'CallExpr') +# +# graph.node('MemberExpr', label='MemberExpr(val=".mymethod()")') +# graph.edge('CallExpr', 'MemberExpr') +# +# graph.node('CallExpr_outer_Args', label='Args()') +# graph.edge('CallExpr', 'CallExpr_outer_Args') +# +# graph.node('CallExpr_inner', label='CallExpr(val="mymethod()")') +# graph.edge('MemberExpr', 'CallExpr_inner') +# +# graph.node('NameExpr', label='NameExpr(val="mymethod")') +# graph.edge('CallExpr_inner', 'NameExpr') +# +# graph.node('Expression_Args', label='Args()') +# graph.edge('CallExpr_inner', 'Expression_Args') + +graph.render(view=True, format='png') + + +# MypyFile( +# ClassDef( +# name=MyQuerySet, +# FuncDef( +# name=mymethod, +# Args( +# Var(self)) +# Block(PassStmt()) +# ) +# ) +# ExpressionStmt:6( +# CallExpr:6( +# MemberExpr:6( +# CallExpr:6( +# NameExpr(MyQuerySet) +# Args()) +# mymethod) +# Args()))) diff --git a/mfile.py.gv b/mfile.py.gv new file mode 100644 index 0000000..58ca391 --- /dev/null +++ b/mfile.py.gv @@ -0,0 +1,13 @@ +digraph "mfile.py" { + graph [fontcolor=blue labelloc=t margin=0] + node [fontcolor="#00008b"] + __builtins__ + "django.db.models" + "django.db.models.fields" + "django.db.models" -> "django.db.models.fields" + "django.db.models" -> __builtins__ + "django.db.models.fields" -> __builtins__ + mymodule + mymodule -> "django.db.models" + mymodule -> __builtins__ +} diff --git a/mfile.py.gv.pdf b/mfile.py.gv.pdf new file mode 100644 index 0000000..2be078e Binary files /dev/null and b/mfile.py.gv.pdf differ diff --git a/mfile.py.gv.png b/mfile.py.gv.png new file mode 100644 index 0000000..a795d1b Binary files /dev/null and b/mfile.py.gv.png differ diff --git a/my.gv b/my.gv new file mode 100644 index 0000000..3e79d53 --- /dev/null +++ b/my.gv @@ -0,0 +1,9 @@ +digraph AST { + File + ClassDef + ClassDef -> File + FuncDef + FuncDef -> ClassDef + ExpressionStmt + ExpressionStmt -> File +} diff --git a/my.gv.pdf b/my.gv.pdf new file mode 100644 index 0000000..5fc254d Binary files /dev/null and b/my.gv.pdf differ diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index faa56b9..1f60226 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -11,7 +11,7 @@ from django.db import models from django.db.models.base import Model from django.db.models.fields import AutoField, CharField, Field from django.db.models.fields.related import ForeignKey, RelatedField -from django.db.models.fields.reverse_related import ForeignObjectRel +from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel from django.db.models.lookups import Exact from django.db.models.sql.query import Query from django.utils.functional import cached_property @@ -119,10 +119,10 @@ class DjangoContext: if isinstance(field, Field): yield field - def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]: - for field in model_cls._meta.get_fields(): - if isinstance(field, ForeignObjectRel): - yield field + def get_model_relations(self, model_cls: Type[Model]) -> Iterator[Tuple[Optional[str], ForeignObjectRel]]: + for relation in model_cls._meta.get_fields(): + if isinstance(relation, ForeignObjectRel): + yield relation.get_accessor_name(), relation def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType: if isinstance(field, (RelatedField, ForeignObjectRel)): diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index cc68aa4..0e86466 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -10,11 +10,11 @@ from mypy.mro import calculate_mro from mypy.nodes import ( Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode, TypeInfo, Var, - CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo) -from mypy.plugin import DynamicClassDefContext, ClassDefContext + CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo, OverloadedFuncDef, Decorator) +from mypy.plugin import DynamicClassDefContext, ClassDefContext, AttributeContext, MethodContext from mypy.plugins.common import add_method -from mypy.semanal import SemanticAnalyzer -from mypy.types import AnyType, Instance, NoneTyp, TypeType +from mypy.semanal import SemanticAnalyzer, is_valid_replacement, is_same_symbol +from mypy.types import AnyType, Instance, NoneTyp, TypeType, ProperType, CallableType from mypy.types import Type as MypyType from mypy.types import TypeOfAny, UnionType from mypy.typetraverser import TypeTraverserVisitor @@ -38,8 +38,25 @@ class DjangoPluginCallback: self.plugin = plugin self.django_context = plugin.django_context - # def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]: - # return self.plugin.lookup_fully_qualified(fullname) + def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo: + class_def = ClassDef(name, Block([])) + class_def.fullname = self.qualified_name(name) + + info = TypeInfo(SymbolTable(), class_def, self.get_current_module().fullname) + info.bases = bases + calculate_mro(info) + info.metaclass_type = info.calculate_metaclass_type() + + class_def.info = info + return info + + @abstractmethod + def get_current_module(self) -> MypyFile: + raise NotImplementedError() + + @abstractmethod + def qualified_name(self, name: str) -> str: + raise NotImplementedError() class SemanalPluginCallback(DjangoPluginCallback): @@ -58,6 +75,12 @@ class SemanalPluginCallback(DjangoPluginCallback): print(f'LOG: defer: {self.build_defer_error_message(reason)}') return True + def get_current_module(self) -> MypyFile: + return self.semanal_api.cur_mod_node + + def qualified_name(self, name: str) -> str: + return self.semanal_api.qualified_name(name) + def lookup_typeinfo_or_defer(self, fullname: str, *, deferral_context: Optional[Context] = None, reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]: @@ -74,11 +97,12 @@ class SemanalPluginCallback(DjangoPluginCallback): return sym.node - def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo: + def new_typeinfo(self, name: str, bases: List[Instance], module_fullname: Optional[str] = None) -> TypeInfo: class_def = ClassDef(name, Block([])) class_def.fullname = self.semanal_api.qualified_name(name) - info = TypeInfo(SymbolTable(), class_def, self.semanal_api.cur_mod_id) + info = TypeInfo(SymbolTable(), class_def, + module_fullname or self.get_current_module().fullname) info.bases = bases calculate_mro(info) info.metaclass_type = info.calculate_metaclass_type() @@ -86,6 +110,43 @@ class SemanalPluginCallback(DjangoPluginCallback): class_def.info = info return info + def add_symbol_table_node(self, + name: str, + symbol: SymbolTableNode, + symbol_table: Optional[SymbolTable] = None, + context: Optional[Context] = None, + can_defer: bool = True, + escape_comprehensions: bool = False) -> None: + """ Patched copy of SemanticAnalyzer.add_symbol_table_node(). """ + names = symbol_table or self.semanal_api.current_symbol_table(escape_comprehensions=escape_comprehensions) + existing = names.get(name) + if isinstance(symbol.node, PlaceholderNode) and can_defer: + self.semanal_api.defer(context) + return None + if (existing is not None + and context is not None + and not is_valid_replacement(existing, symbol)): + # There is an existing node, so this may be a redefinition. + # If the new node points to the same node as the old one, + # or if both old and new nodes are placeholders, we don't + # need to do anything. + old = existing.node + new = symbol.node + if isinstance(new, PlaceholderNode): + # We don't know whether this is okay. Let's wait until the next iteration. + return False + if not is_same_symbol(old, new): + if isinstance(new, (FuncDef, Decorator, OverloadedFuncDef, TypeInfo)): + self.semanal_api.add_redefinition(names, name, symbol) + if not (isinstance(new, (FuncDef, Decorator)) + and self.semanal_api.set_original_def(old, new)): + self.semanal_api.name_already_defined(name, context, existing) + elif name not in self.semanal_api.missing_names and '*' not in self.semanal_api.missing_names: + names[name] = symbol + self.progress = True + return None + raise new_helpers.SymbolAdditionNotPossible() + # def add_symbol_table_node_or_defer(self, name: str, sym: SymbolTableNode) -> bool: # return self.semanal_api.add_symbol_table_node(name, sym, # context=self.semanal_api.cur_mod_node) @@ -119,20 +180,6 @@ class SemanalPluginCallback(DjangoPluginCallback): self.semanal_api.add_imported_symbol(name, sym, context=self.semanal_api.cur_mod_node) class UnimportedTypesVisitor(TypeTraverserVisitor): - def visit_union_type(self, t: UnionType) -> None: - super().visit_union_type(t) - union_sym = currently_imported_symbols.get('Union') - if union_sym is None: - # TODO: check if it's exactly typing.Union - import_symbol_from_source('Union') - - def visit_type_type(self, t: TypeType) -> None: - super().visit_type_type(t) - type_sym = currently_imported_symbols.get('Union') - if type_sym is None: - # TODO: check if it's exactly typing.Type - import_symbol_from_source('Type') - def visit_instance(self, t: Instance) -> None: super().visit_instance(t) if isinstance(t.type, FakeInfo): @@ -140,7 +187,6 @@ class SemanalPluginCallback(DjangoPluginCallback): type_name = t.type.name sym = currently_imported_symbols.get(type_name) if sym is None: - # TODO: check if it's exactly typing.Type import_symbol_from_source(type_name) signature_node.type.accept(UnimportedTypesVisitor()) @@ -202,11 +248,13 @@ class DynamicClassPluginCallback(SemanalPluginCallback): class ClassDefPluginCallback(SemanalPluginCallback): reason: Expression class_defn: ClassDef + ctx: ClassDefContext def __call__(self, ctx: ClassDefContext) -> None: self.reason = ctx.reason self.class_defn = ctx.cls self.semanal_api = cast(SemanticAnalyzer, ctx.api) + self.ctx = ctx self.modify_class_defn() @abstractmethod @@ -214,6 +262,64 @@ class ClassDefPluginCallback(SemanalPluginCallback): raise NotImplementedError +class TypeCheckerPluginCallback(DjangoPluginCallback): + type_checker: TypeChecker + + def get_current_module(self) -> MypyFile: + current_module = None + for item in reversed(self.type_checker.scope.stack): + if isinstance(item, MypyFile): + current_module = item + break + assert current_module is not None + return current_module + + def qualified_name(self, name: str) -> str: + return self.type_checker.scope.stack[-1].fullname + '.' + name + + def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: + sym = self.plugin.lookup_fully_qualified(fullname) + if sym is None or sym.node is None: + return None + if not isinstance(sym.node, TypeInfo): + raise ValueError(f'{fullname!r} does not correspond to TypeInfo') + return sym.node + + +class GetMethodPluginCallback(TypeCheckerPluginCallback): + callee_type: Instance + ctx: MethodContext + + def __call__(self, ctx: MethodContext) -> MypyType: + self.type_checker = ctx.api + + assert isinstance(ctx.type, CallableType) + self.callee_type = ctx.type.ret_type + self.ctx = ctx + return self.get_method_return_type() + + @abstractmethod + def get_method_return_type(self) -> MypyType: + raise NotImplementedError + + +class GetAttributeCallback(TypeCheckerPluginCallback): + obj_type: ProperType + default_attr_type: MypyType + error_context: MemberExpr + name: str + + def __call__(self, ctx: AttributeContext) -> MypyType: + self.ctx = ctx + self.type_checker = ctx.api + self.obj_type = ctx.type + self.default_attr_type = ctx.default_attr_type + self.error_context = ctx.context + assert isinstance(self.error_context, MemberExpr) + self.name = self.error_context.name + return self.default_attr_type + + def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: return model_info.metadata.setdefault('django', {}) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 5302ec2..70f8d8f 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -17,11 +17,10 @@ from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.transformers import ( fields, forms, init_create, meta, querysets, request, settings, ) -from mypy_django_plugin.transformers.managers import ( - create_manager_class_from_as_manager_method, instantiate_anonymous_queryset_from_as_manager) from mypy_django_plugin.transformers.models import process_model_class from mypy_django_plugin.transformers2.dynamic_managers import CreateNewManagerClassFrom_FromQuerySet from mypy_django_plugin.transformers2.models import ModelCallback +from mypy_django_plugin.transformers2.related_managers import GetRelatedManagerCallback def transform_model_class(ctx: ClassDefContext, @@ -176,10 +175,6 @@ class NewSemanalDjangoPlugin(Plugin): if fullname == 'django.contrib.auth.get_user_model': return partial(settings.get_user_model_hook, django_context=self.django_context) - # manager_bases = self._get_current_manager_bases() - # if fullname in manager_bases: - # return querysets.determine_proper_manager_type - info = self._get_typeinfo_or_none(fullname) if info: if info.has_base(fullnames.FIELD_FULLNAME): @@ -217,11 +212,6 @@ class NewSemanalDjangoPlugin(Plugin): if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) - if method_name == 'as_manager': - info = self._get_typeinfo_or_none(class_fullname) - if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - return instantiate_anonymous_queryset_from_as_manager - manager_classes = self._get_current_manager_bases() if class_fullname in manager_classes and method_name == 'create': return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) @@ -253,6 +243,10 @@ class NewSemanalDjangoPlugin(Plugin): info = self._get_typeinfo_or_none(class_name) if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user': return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) + + if info and info.has_base(fullnames.MODEL_CLASS_FULLNAME): + return GetRelatedManagerCallback(self) + return None def get_dynamic_class_hook(self, fullname: str @@ -263,12 +257,6 @@ class NewSemanalDjangoPlugin(Plugin): if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): return CreateNewManagerClassFrom_FromQuerySet(self) - if fullname.endswith('as_manager'): - class_name, _, _ = fullname.rpartition('.') - info = self._get_typeinfo_or_none(class_name) - if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - return create_manager_class_from_as_manager_method - return None diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 9bb7cea..9731b14 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -230,102 +230,3 @@ def add_symbol_table_node(api: SemanticAnalyzer, return True return False - -class CreateNewManagerClassFrom_AsManager(helpers.DynamicClassPluginCallback): - def create_new_dynamic_class(self) -> None: - pass - - -def create_manager_class_from_as_manager_method(ctx: DynamicClassDefContext) -> None: - semanal_api = sem_helpers.get_semanal_api(ctx) - try: - queryset_info = resolve_callee_info_or_exception(ctx) - django_manager_info = resolve_django_manager_info_or_exception(ctx) - except sem_helpers.IncompleteDefnError: - if not semanal_api.final_iteration: - semanal_api.defer() - return - else: - raise - - generic_param: MypyType = AnyType(TypeOfAny.explicit) - generic_param_name = 'Any' - if (semanal_api.scope.classes - and semanal_api.scope.classes[-1].has_base(fullnames.MODEL_CLASS_FULLNAME)): - info = semanal_api.scope.classes[-1] # type: TypeInfo - generic_param = Instance(info, []) - generic_param_name = info.name - - new_manager_class_name = queryset_info.name + '_AsManager_' + generic_param_name - new_manager_info = helpers.new_typeinfo(new_manager_class_name, - bases=[Instance(django_manager_info, [generic_param])], - module_name=semanal_api.cur_mod_id) - new_manager_info.set_line(ctx.call) - - record_new_manager_info_fullname_into_metadata(ctx, - new_manager_info.fullname, - django_manager_info, - queryset_info, - django_manager_info) - - class_def_context = ClassDefContext(cls=new_manager_info.defn, - reason=ctx.call, api=semanal_api) - self_type = Instance(new_manager_info, [AnyType(TypeOfAny.explicit)]) - - try: - for name, method_node in iter_all_custom_queryset_methods(queryset_info): - sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context, - self_type, - new_method_name=name, - method_node=method_node) - except sem_helpers.IncompleteDefnError: - if not semanal_api.final_iteration: - semanal_api.defer() - return - else: - raise - - new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) - - # context=None - forcibly replace old node - added = add_symbol_table_node(semanal_api, new_manager_class_name, new_manager_sym, - context=None, - symbol_table=semanal_api.globals) - if added: - # replace all references to the old manager Var everywhere - for _, module in semanal_api.modules.items(): - if module.fullname != semanal_api.cur_mod_id: - for sym_name, sym in module.names.items(): - if sym.fullname == new_manager_info.fullname: - module.names[sym_name] = new_manager_sym.copy() - - # we need another iteration to process methods - if (not added - and not semanal_api.final_iteration): - semanal_api.defer() - - -def instantiate_anonymous_queryset_from_as_manager(ctx: MethodContext) -> MypyType: - api = chk_helpers.get_typechecker_api(ctx) - django_manager_info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME) - assert django_manager_info is not None - - assert isinstance(ctx.type, CallableType) - assert isinstance(ctx.type.ret_type, Instance) - queryset_info = ctx.type.ret_type.type - - gen_name = django_manager_info.name + 'From' + queryset_info.name - gen_fullname = 'django.db.models.manager' + '.' + gen_name - - metadata = get_generated_managers_metadata(django_manager_info) - if gen_fullname not in metadata: - raise ValueError(f'{gen_fullname!r} is not present in generated managers list') - - module_name, _, class_name = metadata[gen_fullname].rpartition('.') - current_module = helpers.get_current_module(api) - assert module_name == current_module.fullname - - generated_manager_info = current_module.names[class_name].node - assert isinstance(generated_manager_info, TypeInfo) - - return Instance(generated_manager_info, []) diff --git a/mypy_django_plugin/transformers2/dynamic_managers.py b/mypy_django_plugin/transformers2/dynamic_managers.py index 9c927e5..275f780 100644 --- a/mypy_django_plugin/transformers2/dynamic_managers.py +++ b/mypy_django_plugin/transformers2/dynamic_managers.py @@ -4,11 +4,17 @@ from mypy.checker import gen_unique_name from mypy.nodes import NameExpr, TypeInfo, SymbolTableNode, StrExpr from mypy.types import Type as MypyType, TypeVarType, TypeVarDef, Instance -from mypy_django_plugin.lib import helpers, fullnames, chk_helpers, sem_helpers +from mypy_django_plugin.lib import helpers, fullnames from mypy_django_plugin.transformers.managers import iter_all_custom_queryset_methods class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback): + def set_manager_mapping(self, runtime_manager_fullname: str, generated_manager_fullname: str) -> None: + base_model_info = self.lookup_typeinfo_or_defer(fullnames.MODEL_CLASS_FULLNAME) + assert base_model_info is not None + managers_metadata = base_model_info.metadata.setdefault('managers', {}) + managers_metadata[runtime_manager_fullname] = generated_manager_fullname + def create_typevar_in_current_module(self, name: str, upper_bound: Optional[MypyType] = None) -> TypeVarDef: tvar_name = gen_unique_name(name, self.semanal_api.globals) @@ -48,19 +54,20 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback) parent_manager_type = Instance(callee_manager_info, [model_tvar_type]) # instantiate with a proper model, Manager[MyModel], filling all Manager type vars in process + queryset_type = Instance(passed_queryset_info, [Instance(base_model_info, [])]) new_manager_info = self.new_typeinfo(self.class_name, - bases=[parent_manager_type]) + bases=[queryset_type, parent_manager_type]) new_manager_info.defn.type_vars = [model_tvar_defn] new_manager_info.type_vars = [model_tvar_defn.name] new_manager_info.set_line(self.call_expr) # copy methods from passed_queryset_info with self type replaced - self_type = Instance(new_manager_info, [model_tvar_type]) - for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info): - self.add_method_from_signature(method_node, - name, - self_type, - new_manager_info.defn) + # self_type = Instance(new_manager_info, [model_tvar_type]) + # for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info): + # self.add_method_from_signature(method_node, + # name, + # self_type, + # new_manager_info.defn) new_manager_sym = SymbolTableNode(self.semanal_api.current_symbol_kind(), new_manager_info, @@ -75,5 +82,5 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback) runtime_manager_class_name = class_name_arg.value new_manager_name = runtime_manager_class_name or (callee_manager_info.name + 'From' + queryset_class_name) - django_generated_manager_name = 'django.db.models.manager.' + new_manager_name - base_model_info.metadata.setdefault('managers', {})[django_generated_manager_name] = new_manager_info.fullname + self.set_manager_mapping(f'django.db.models.manager.{new_manager_name}', + new_manager_info.fullname) diff --git a/mypy_django_plugin/transformers2/models.py b/mypy_django_plugin/transformers2/models.py index 12962e6..a8814d0 100644 --- a/mypy_django_plugin/transformers2/models.py +++ b/mypy_django_plugin/transformers2/models.py @@ -3,16 +3,16 @@ from typing import Type, Optional from django.db.models.base import Model from django.db.models.fields.related import OneToOneField, ForeignKey -from django.db.models.fields.reverse_related import OneToOneRel, ManyToManyRel, ManyToOneRel -from mypy.checker import gen_unique_name -from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF +from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF, Argument, ARG_STAR2 from mypy.plugin import ClassDefContext +from mypy.plugins import common from mypy.semanal import dummy_context from mypy.types import Instance, TypeOfAny, AnyType from mypy.types import Type as MypyType from django.db import models -from mypy_django_plugin.lib import helpers, fullnames +from django.db.models.fields import DateField, DateTimeField +from mypy_django_plugin.lib import helpers, fullnames, sem_helpers from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers.fields import get_field_type from mypy_django_plugin.transformers2 import new_helpers @@ -116,76 +116,77 @@ class AddPrimaryKeyIfDoesNotExist(TransformModelClassCallback): class AddRelatedManagersCallback(TransformModelClassCallback): def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: - for relation in self.django_context.get_model_relations(runtime_model_cls): - reverse_manager_name = relation.get_accessor_name() + for reverse_manager_name, relation in self.django_context.get_model_relations(runtime_model_cls): if (reverse_manager_name is None or reverse_manager_name in self.class_defn.info.names): continue - related_model_cls = self.django_context.get_field_related_model_cls(relation) - if related_model_cls is None: - # could not find a referenced model (maybe invalid to= value, or GenericForeignKey) - continue - - related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls) - if related_model_info is None: - continue - - if isinstance(relation, OneToOneRel): - self.add_new_model_attribute(reverse_manager_name, - Instance(related_model_info, [])) - elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): - related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS) - if related_manager_info is None: - if not self.defer_till_next_iteration(self.class_defn, - reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'): - raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS) - continue - - # get type of default_manager for model - default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) - reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} ' - f'of model {helpers.get_class_fullname(related_model_cls)!r}') - default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname, - reason_for_defer=reason_for_defer) - if default_manager_info is None: - continue - - default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) - - # related_model_cls._meta.default_manager.__class__ - # # we're making a subclass of 'objects', need to have it defined - # if 'objects' not in related_model_info.names: - # if not self.defer_till_next_iteration(self.class_defn, - # reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"): - # raise AttributeNotFound(self.class_defn.info, 'objects') - # continue - - related_manager_type = Instance(related_manager_info, - [Instance(related_model_info, [])]) - # - # objects_sym = related_model_info.names['objects'] - # default_manager_type = objects_sym.type - # if default_manager_type is None: - # # dynamic base class, extract from django_context - # default_manager_cls = related_model_cls._meta.default_manager.__class__ - # default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls) - # if default_manager_info is None: - # continue - # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) - - if (not isinstance(default_manager_type, Instance) - or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): - # if not defined or trivial -> just return RelatedManager[Model] - self.add_new_model_attribute(reverse_manager_name, related_manager_type) - continue - - # make anonymous class - name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', - self.semanal_api.current_symbol_table()) - bases = [related_manager_type, default_manager_type] - new_manager_info = self.new_typeinfo(name, bases) - self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, [])) + self.add_new_model_attribute(reverse_manager_name, AnyType(TypeOfAny.implementation_artifact)) + # + # related_model_cls = self.django_context.get_field_related_model_cls(relation) + # if related_model_cls is None: + # # could not find a referenced model (maybe invalid to= value, or GenericForeignKey) + # continue + # + # related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls) + # if related_model_info is None: + # continue + # + # if isinstance(relation, OneToOneRel): + # self.add_new_model_attribute(reverse_manager_name, + # Instance(related_model_info, [])) + # elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): + # related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS) + # if related_manager_info is None: + # if not self.defer_till_next_iteration(self.class_defn, + # reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'): + # raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS) + # continue + # + # # get type of default_manager for model + # default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) + # reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} ' + # f'of model {helpers.get_class_fullname(related_model_cls)!r}') + # default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname, + # reason_for_defer=reason_for_defer) + # if default_manager_info is None: + # continue + # + # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) + # + # # related_model_cls._meta.default_manager.__class__ + # # # we're making a subclass of 'objects', need to have it defined + # # if 'objects' not in related_model_info.names: + # # if not self.defer_till_next_iteration(self.class_defn, + # # reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"): + # # raise AttributeNotFound(self.class_defn.info, 'objects') + # # continue + # + # related_manager_type = Instance(related_manager_info, + # [Instance(related_model_info, [])]) + # # + # # objects_sym = related_model_info.names['objects'] + # # default_manager_type = objects_sym.type + # # if default_manager_type is None: + # # # dynamic base class, extract from django_context + # # default_manager_cls = related_model_cls._meta.default_manager.__class__ + # # default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls) + # # if default_manager_info is None: + # # continue + # # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) + # + # if (not isinstance(default_manager_type, Instance) + # or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): + # # if not defined or trivial -> just return RelatedManager[Model] + # self.add_new_model_attribute(reverse_manager_name, related_manager_type) + # continue + # + # # make anonymous class + # name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', + # self.semanal_api.current_symbol_table()) + # bases = [related_manager_type, default_manager_type] + # new_manager_info = self.new_typeinfo(name, bases) + # self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, [])) class AddForeignPrimaryKeys(TransformModelClassCallback): @@ -222,6 +223,69 @@ class AddForeignPrimaryKeys(TransformModelClassCallback): self.add_new_model_attribute(rel_pk_field_name, field_type) +class InjectAnyAsBaseForNestedMeta(TransformModelClassCallback): + """ + Replaces + class MyModel(models.Model): + class Meta: + pass + with + class MyModel(models.Model): + class Meta(Any): + pass + to get around incompatible Meta inner classes for different models. + """ + + def modify_class_defn(self) -> None: + meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.class_defn.info) + if meta_node is None: + return None + meta_node.fallback_to_any = True + + +class AddMetaOptionsAttribute(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + if '_meta' not in self.class_defn.info.names: + options_info = self.lookup_typeinfo_or_defer(fullnames.OPTIONS_CLASS_FULLNAME) + if options_info is not None: + self.add_new_model_attribute('_meta', + Instance(options_info, [ + Instance(self.class_defn.info, []) + ])) + + +class AddExtraFieldMethods(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + # get_FOO_display for choices + for field in self.django_context.get_model_fields(runtime_model_cls): + if field.choices: + info = self.lookup_typeinfo_or_defer('builtins.str') + return_type = Instance(info, []) + common.add_method(self.ctx, + name='get_{}_display'.format(field.attname), + args=[], + return_type=return_type) + + # get_next_by, get_previous_by for Date, DateTime + for field in self.django_context.get_model_fields(runtime_model_cls): + if isinstance(field, (DateField, DateTimeField)) and not field.null: + return_type = Instance(self.class_defn.info, []) + common.add_method(self.ctx, + name='get_next_by_{}'.format(field.attname), + args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)), + AnyType(TypeOfAny.explicit), + initializer=None, + kind=ARG_STAR2)], + return_type=return_type) + common.add_method(self.ctx, + name='get_previous_by_{}'.format(field.attname), + args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)), + AnyType(TypeOfAny.explicit), + initializer=None, + kind=ARG_STAR2)], + return_type=return_type) + + class ModelCallback(helpers.ClassDefPluginCallback): def __call__(self, ctx: ClassDefContext) -> None: callback_classes = [ @@ -230,6 +294,9 @@ class ModelCallback(helpers.ClassDefPluginCallback): AddForeignPrimaryKeys, AddDefaultManagerCallback, AddRelatedManagersCallback, + InjectAnyAsBaseForNestedMeta, + AddMetaOptionsAttribute, + AddExtraFieldMethods, ] for callback_cls in callback_classes: callback = callback_cls(self.plugin) diff --git a/mypy_django_plugin/transformers2/new_helpers.py b/mypy_django_plugin/transformers2/new_helpers.py index 87f6958..ff1121c 100644 --- a/mypy_django_plugin/transformers2/new_helpers.py +++ b/mypy_django_plugin/transformers2/new_helpers.py @@ -22,5 +22,9 @@ class NameNotFound(IncompleteDefnError): super().__init__(f'Could not find {name!r} in the current activated namespaces') +class SymbolAdditionNotPossible(Exception): + pass + + def get_class_fullname(klass: type) -> str: return klass.__module__ + '.' + klass.__qualname__ diff --git a/mypy_django_plugin/transformers2/related_managers.py b/mypy_django_plugin/transformers2/related_managers.py new file mode 100644 index 0000000..b5c6cce --- /dev/null +++ b/mypy_django_plugin/transformers2/related_managers.py @@ -0,0 +1,69 @@ +from mypy.checker import gen_unique_name +from mypy.plugin import AttributeContext +from mypy.types import Instance +from mypy.types import Type as MypyType + +from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel, ManyToOneRel, ManyToManyRel + +from mypy_django_plugin.lib import helpers, fullnames +from mypy_django_plugin.lib.helpers import GetAttributeCallback + + +class GetRelatedManagerCallback(GetAttributeCallback): + obj_type: Instance + + def get_related_manager_type(self, relation: ForeignObjectRel) -> MypyType: + related_model_cls = self.django_context.get_field_related_model_cls(relation) + if related_model_cls is None: + # could not find a referenced model (maybe invalid to= value, or GenericForeignKey) + # TODO: show error + return self.default_attr_type + + related_model_info = self.lookup_typeinfo(helpers.get_class_fullname(related_model_cls)) + if related_model_info is None: + # TODO: show error + return self.default_attr_type + + if isinstance(relation, OneToOneRel): + return Instance(related_model_info, []) + + elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): + related_manager_info = self.lookup_typeinfo(fullnames.RELATED_MANAGER_CLASS) + if related_manager_info is None: + return self.default_attr_type + + # get type of default_manager for model + default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) + default_manager_info = self.lookup_typeinfo(default_manager_fullname) + if default_manager_info is None: + return self.default_attr_type + + default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) + related_manager_type = Instance(related_manager_info, + [Instance(related_model_info, [])]) + + if (not isinstance(default_manager_type, Instance) + or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): + # if not defined or trivial -> just return RelatedManager[Model] + return related_manager_type + + # make anonymous class + name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', + self.obj_type.type.names) + bases = [related_manager_type, default_manager_type] + new_manager_info = self.new_typeinfo(name, bases) + return Instance(new_manager_info, []) + + def __call__(self, ctx: AttributeContext): + super().__call__(ctx) + assert isinstance(self.obj_type, Instance) + + model_fullname = self.obj_type.type.fullname + model_cls = self.django_context.get_model_class_by_fullname(model_fullname) + if model_cls is None: + return self.default_attr_type + for reverse_manager_name, relation in self.django_context.get_model_relations(model_cls): + if reverse_manager_name == self.name: + return self.get_related_manager_type(relation) + + return self.default_attr_type diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index a757d37..518aba9 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -653,7 +653,7 @@ - case: related_manager_is_a_subclass_of_default_manager main: | from myapp.models import User - reveal_type(User().orders) # N: Revealed type is 'myapp.models.User.Order_RelatedManager' + reveal_type(User().orders) # N: Revealed type is 'main.Order_RelatedManager' reveal_type(User().orders.get()) # N: Revealed type is 'myapp.models.Order*' reveal_type(User().orders.manager_method()) # N: Revealed type is 'builtins.int' installed_apps: diff --git a/test-data/typecheck/managers/querysets/test_as_manager.yml b/test-data/typecheck/managers/querysets/test_as_manager.yml deleted file mode 100644 index fe2c20d..0000000 --- a/test-data/typecheck/managers/querysets/test_as_manager.yml +++ /dev/null @@ -1,95 +0,0 @@ -- case: anonymous_queryset_from_as_manager_inside_model - main: | - from myapp.models import MyModel - - reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_MyModel' - reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*' - reveal_type(MyModel.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' - reveal_type(MyModel.objects.queryset_method()) # N: Revealed type is 'builtins.int' - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyQuerySet(models.QuerySet): - def queryset_method(self) -> int: - pass - class MyModel(models.Model): - objects = MyQuerySet.as_manager() - - -- case: two_invocations_parametrized_with_different_models - main: | - from myapp.models import User, Blog - reveal_type(User.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_User' - reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' - reveal_type(User.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' - reveal_type(User.objects.queryset_method()) # N: Revealed type is 'builtins.int' - - reveal_type(Blog.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Blog' - reveal_type(Blog.objects.get()) # N: Revealed type is 'myapp.models.Blog*' - reveal_type(Blog.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' - reveal_type(Blog.objects.queryset_method()) # N: Revealed type is 'builtins.int' - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyQuerySet(models.QuerySet): - def queryset_method(self) -> int: - pass - class User(models.Model): - objects = MyQuerySet.as_manager() - class Blog(models.Model): - objects = MyQuerySet.as_manager() - - -- case: as_manager_outside_model_parametrized_with_any - main: | - from myapp.models import NotModel, outside_objects - reveal_type(NotModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any' - reveal_type(NotModel.objects.get()) # N: Revealed type is 'Any' - reveal_type(outside_objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any' - reveal_type(outside_objects.get()) # N: Revealed type is 'Any' - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyQuerySet(models.QuerySet): - def queryset_method(self) -> int: - pass - outside_objects = MyQuerySet.as_manager() - class NotModel: - objects = MyQuerySet.as_manager() - -- case: test_as_manager_without_name_to_bind_in_different_files - main: | - from myapp.models import MyQuerySet - reveal_type(MyQuerySet.as_manager()) # N: Revealed type is 'Any' - reveal_type(MyQuerySet.as_manager().get()) # N: Revealed type is 'Any' - reveal_type(MyQuerySet.as_manager().mymethod()) # N: Revealed type is 'Any' - - from myapp import helpers - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class MyQuerySet(models.QuerySet): - def mymethod(self) -> int: - pass - class MyModel(models.Model): - objects = MyQuerySet.as_manager() - - path: myapp/helpers.py - content: | - from myapp.models import MyQuerySet - MyQuerySet.as_manager() \ No newline at end of file diff --git a/test-data/typecheck/managers/querysets/test_from_queryset.yml b/test-data/typecheck/managers/querysets/test_from_queryset.yml index 83bd390..a839e3c 100644 --- a/test-data/typecheck/managers/querysets/test_from_queryset.yml +++ b/test-data/typecheck/managers/querysets/test_from_queryset.yml @@ -17,11 +17,15 @@ - path: myapp/__init__.py - path: myapp/models.py content: | + from typing import TypeVar from django.db import models from django.db.models.manager import BaseManager, Manager from mypy_django_plugin.lib import generics - class ModelQuerySet(models.QuerySet): + generics.make_classes_generic(models.QuerySet) + _M = TypeVar('_M', bound=models.Model) + + class ModelQuerySet(models.QuerySet[_M]): def queryset_method(self) -> str: return 'hello' diff --git a/test-output/round-table.gv b/test-output/round-table.gv new file mode 100644 index 0000000..9281e83 --- /dev/null +++ b/test-output/round-table.gv @@ -0,0 +1,3 @@ +digraph { + FuncDef [label="My FuncDef"] +} diff --git a/test-output/round-table.gv.pdf b/test-output/round-table.gv.pdf new file mode 100644 index 0000000..3478b08 Binary files /dev/null and b/test-output/round-table.gv.pdf differ