This commit is contained in:
Maxim Kurnikov
2020-03-15 00:58:11 +03:00
parent 0b1507c81e
commit 1419b144d9
20 changed files with 513 additions and 321 deletions

View File

@@ -42,6 +42,9 @@ _ST = TypeVar("_ST")
# __get__ return type # __get__ return type
_GT = TypeVar("_GT") _GT = TypeVar("_GT")
class CharField(Field[str, str]):
class Field(RegisterLookupMixin, Generic[_ST, _GT]): class Field(RegisterLookupMixin, Generic[_ST, _GT]):
_pyi_private_set_type: Any _pyi_private_set_type: Any
_pyi_private_get_type: Any _pyi_private_get_type: Any

113
mfile.py Normal file
View File

@@ -0,0 +1,113 @@
from graphviz import Digraph
from mypy.options import Options
source = """
from root.package import MyQuerySet
MyQuerySet().mymethod()
"""
from mypy import parse
parsed = parse.parse(source, 'myfile.py', None, None, Options())
print(parsed)
graphattrs = {
"labelloc": "t",
"fontcolor": "blue",
# "bgcolor": "#333333",
"margin": "0",
}
nodeattrs = {
# "color": "white",
"fontcolor": "#00008b",
# "style": "filled",
# "fillcolor": "#ffffff",
# "fillcolor": "#006699",
}
edgeattrs = {
# "color": "white",
# "fontcolor": "white",
}
graph = Digraph('mfile.py', graph_attr=graphattrs, node_attr=nodeattrs, edge_attr=edgeattrs)
graph.node('__builtins__')
graph.node('django.db.models')
graph.node('django.db.models.fields')
graph.edge('django.db.models', 'django.db.models.fields')
graph.edge('django.db.models', '__builtins__')
graph.edge('django.db.models.fields', '__builtins__')
graph.node('mymodule')
graph.edge('mymodule', 'django.db.models')
graph.edge('mymodule', '__builtins__')
#
# graph.node('ImportFrom', label='ImportFrom(val=root.package, [MyQuerySet])')
# graph.edge('MypyFile', 'ImportFrom')
# graph.node('ClassDef_MyQuerySet', label='ClassDef(name=MyQuerySet)')
# graph.edge('MypyFile', 'ClassDef_MyQuerySet')
#
# graph.node('FuncDef_mymethod', label='FuncDef(name=mymethod)')
# graph.edge('ClassDef_MyQuerySet', 'FuncDef_mymethod')
#
# graph.node('Args', label='Args')
# graph.edge('FuncDef_mymethod', 'Args')
#
# graph.node('Var_self', label='Var(name=self)')
# graph.edge('Args', 'Var_self')
#
# graph.node('Block', label='Block')
# graph.edge('FuncDef_mymethod', 'Block')
#
# graph.node('PassStmt')
# graph.edge('Block', 'PassStmt')
# graph.node('ExpressionStmt')
# graph.edge('MypyFile', 'ExpressionStmt')
#
# graph.node('CallExpr', label='CallExpr(val="MyQuerySet()")')
# graph.edge('ExpressionStmt', 'CallExpr')
#
# graph.node('MemberExpr', label='MemberExpr(val=".mymethod()")')
# graph.edge('CallExpr', 'MemberExpr')
#
# graph.node('CallExpr_outer_Args', label='Args()')
# graph.edge('CallExpr', 'CallExpr_outer_Args')
#
# graph.node('CallExpr_inner', label='CallExpr(val="mymethod()")')
# graph.edge('MemberExpr', 'CallExpr_inner')
#
# graph.node('NameExpr', label='NameExpr(val="mymethod")')
# graph.edge('CallExpr_inner', 'NameExpr')
#
# graph.node('Expression_Args', label='Args()')
# graph.edge('CallExpr_inner', 'Expression_Args')
graph.render(view=True, format='png')
# MypyFile(
# ClassDef(
# name=MyQuerySet,
# FuncDef(
# name=mymethod,
# Args(
# Var(self))
# Block(PassStmt())
# )
# )
# ExpressionStmt:6(
# CallExpr:6(
# MemberExpr:6(
# CallExpr:6(
# NameExpr(MyQuerySet)
# Args())
# mymethod)
# Args())))

13
mfile.py.gv Normal file
View File

@@ -0,0 +1,13 @@
digraph "mfile.py" {
graph [fontcolor=blue labelloc=t margin=0]
node [fontcolor="#00008b"]
__builtins__
"django.db.models"
"django.db.models.fields"
"django.db.models" -> "django.db.models.fields"
"django.db.models" -> __builtins__
"django.db.models.fields" -> __builtins__
mymodule
mymodule -> "django.db.models"
mymodule -> __builtins__
}

BIN
mfile.py.gv.pdf Normal file

Binary file not shown.

BIN
mfile.py.gv.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

9
my.gv Normal file
View File

@@ -0,0 +1,9 @@
digraph AST {
File
ClassDef
ClassDef -> File
FuncDef
FuncDef -> ClassDef
ExpressionStmt
ExpressionStmt -> File
}

BIN
my.gv.pdf Normal file

Binary file not shown.

View File

@@ -11,7 +11,7 @@ from django.db import models
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields import AutoField, CharField, Field from django.db.models.fields import AutoField, CharField, Field
from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.lookups import Exact from django.db.models.lookups import Exact
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils.functional import cached_property from django.utils.functional import cached_property
@@ -119,10 +119,10 @@ class DjangoContext:
if isinstance(field, Field): if isinstance(field, Field):
yield field yield field
def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]: def get_model_relations(self, model_cls: Type[Model]) -> Iterator[Tuple[Optional[str], ForeignObjectRel]]:
for field in model_cls._meta.get_fields(): for relation in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel): if isinstance(relation, ForeignObjectRel):
yield field yield relation.get_accessor_name(), relation
def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType: def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)): if isinstance(field, (RelatedField, ForeignObjectRel)):

View File

@@ -10,11 +10,11 @@ from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode,
TypeInfo, Var, TypeInfo, Var,
CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo) CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo, OverloadedFuncDef, Decorator)
from mypy.plugin import DynamicClassDefContext, ClassDefContext from mypy.plugin import DynamicClassDefContext, ClassDefContext, AttributeContext, MethodContext
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer, is_valid_replacement, is_same_symbol
from mypy.types import AnyType, Instance, NoneTyp, TypeType from mypy.types import AnyType, Instance, NoneTyp, TypeType, ProperType, CallableType
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeOfAny, UnionType from mypy.types import TypeOfAny, UnionType
from mypy.typetraverser import TypeTraverserVisitor from mypy.typetraverser import TypeTraverserVisitor
@@ -38,8 +38,25 @@ class DjangoPluginCallback:
self.plugin = plugin self.plugin = plugin
self.django_context = plugin.django_context self.django_context = plugin.django_context
# def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]: def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo:
# return self.plugin.lookup_fully_qualified(fullname) class_def = ClassDef(name, Block([]))
class_def.fullname = self.qualified_name(name)
info = TypeInfo(SymbolTable(), class_def, self.get_current_module().fullname)
info.bases = bases
calculate_mro(info)
info.metaclass_type = info.calculate_metaclass_type()
class_def.info = info
return info
@abstractmethod
def get_current_module(self) -> MypyFile:
raise NotImplementedError()
@abstractmethod
def qualified_name(self, name: str) -> str:
raise NotImplementedError()
class SemanalPluginCallback(DjangoPluginCallback): class SemanalPluginCallback(DjangoPluginCallback):
@@ -58,6 +75,12 @@ class SemanalPluginCallback(DjangoPluginCallback):
print(f'LOG: defer: {self.build_defer_error_message(reason)}') print(f'LOG: defer: {self.build_defer_error_message(reason)}')
return True return True
def get_current_module(self) -> MypyFile:
return self.semanal_api.cur_mod_node
def qualified_name(self, name: str) -> str:
return self.semanal_api.qualified_name(name)
def lookup_typeinfo_or_defer(self, fullname: str, *, def lookup_typeinfo_or_defer(self, fullname: str, *,
deferral_context: Optional[Context] = None, deferral_context: Optional[Context] = None,
reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]: reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]:
@@ -74,11 +97,12 @@ class SemanalPluginCallback(DjangoPluginCallback):
return sym.node return sym.node
def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo: def new_typeinfo(self, name: str, bases: List[Instance], module_fullname: Optional[str] = None) -> TypeInfo:
class_def = ClassDef(name, Block([])) class_def = ClassDef(name, Block([]))
class_def.fullname = self.semanal_api.qualified_name(name) class_def.fullname = self.semanal_api.qualified_name(name)
info = TypeInfo(SymbolTable(), class_def, self.semanal_api.cur_mod_id) info = TypeInfo(SymbolTable(), class_def,
module_fullname or self.get_current_module().fullname)
info.bases = bases info.bases = bases
calculate_mro(info) calculate_mro(info)
info.metaclass_type = info.calculate_metaclass_type() info.metaclass_type = info.calculate_metaclass_type()
@@ -86,6 +110,43 @@ class SemanalPluginCallback(DjangoPluginCallback):
class_def.info = info class_def.info = info
return info return info
def add_symbol_table_node(self,
name: str,
symbol: SymbolTableNode,
symbol_table: Optional[SymbolTable] = None,
context: Optional[Context] = None,
can_defer: bool = True,
escape_comprehensions: bool = False) -> None:
""" Patched copy of SemanticAnalyzer.add_symbol_table_node(). """
names = symbol_table or self.semanal_api.current_symbol_table(escape_comprehensions=escape_comprehensions)
existing = names.get(name)
if isinstance(symbol.node, PlaceholderNode) and can_defer:
self.semanal_api.defer(context)
return None
if (existing is not None
and context is not None
and not is_valid_replacement(existing, symbol)):
# There is an existing node, so this may be a redefinition.
# If the new node points to the same node as the old one,
# or if both old and new nodes are placeholders, we don't
# need to do anything.
old = existing.node
new = symbol.node
if isinstance(new, PlaceholderNode):
# We don't know whether this is okay. Let's wait until the next iteration.
return False
if not is_same_symbol(old, new):
if isinstance(new, (FuncDef, Decorator, OverloadedFuncDef, TypeInfo)):
self.semanal_api.add_redefinition(names, name, symbol)
if not (isinstance(new, (FuncDef, Decorator))
and self.semanal_api.set_original_def(old, new)):
self.semanal_api.name_already_defined(name, context, existing)
elif name not in self.semanal_api.missing_names and '*' not in self.semanal_api.missing_names:
names[name] = symbol
self.progress = True
return None
raise new_helpers.SymbolAdditionNotPossible()
# def add_symbol_table_node_or_defer(self, name: str, sym: SymbolTableNode) -> bool: # def add_symbol_table_node_or_defer(self, name: str, sym: SymbolTableNode) -> bool:
# return self.semanal_api.add_symbol_table_node(name, sym, # return self.semanal_api.add_symbol_table_node(name, sym,
# context=self.semanal_api.cur_mod_node) # context=self.semanal_api.cur_mod_node)
@@ -119,20 +180,6 @@ class SemanalPluginCallback(DjangoPluginCallback):
self.semanal_api.add_imported_symbol(name, sym, context=self.semanal_api.cur_mod_node) self.semanal_api.add_imported_symbol(name, sym, context=self.semanal_api.cur_mod_node)
class UnimportedTypesVisitor(TypeTraverserVisitor): class UnimportedTypesVisitor(TypeTraverserVisitor):
def visit_union_type(self, t: UnionType) -> None:
super().visit_union_type(t)
union_sym = currently_imported_symbols.get('Union')
if union_sym is None:
# TODO: check if it's exactly typing.Union
import_symbol_from_source('Union')
def visit_type_type(self, t: TypeType) -> None:
super().visit_type_type(t)
type_sym = currently_imported_symbols.get('Union')
if type_sym is None:
# TODO: check if it's exactly typing.Type
import_symbol_from_source('Type')
def visit_instance(self, t: Instance) -> None: def visit_instance(self, t: Instance) -> None:
super().visit_instance(t) super().visit_instance(t)
if isinstance(t.type, FakeInfo): if isinstance(t.type, FakeInfo):
@@ -140,7 +187,6 @@ class SemanalPluginCallback(DjangoPluginCallback):
type_name = t.type.name type_name = t.type.name
sym = currently_imported_symbols.get(type_name) sym = currently_imported_symbols.get(type_name)
if sym is None: if sym is None:
# TODO: check if it's exactly typing.Type
import_symbol_from_source(type_name) import_symbol_from_source(type_name)
signature_node.type.accept(UnimportedTypesVisitor()) signature_node.type.accept(UnimportedTypesVisitor())
@@ -202,11 +248,13 @@ class DynamicClassPluginCallback(SemanalPluginCallback):
class ClassDefPluginCallback(SemanalPluginCallback): class ClassDefPluginCallback(SemanalPluginCallback):
reason: Expression reason: Expression
class_defn: ClassDef class_defn: ClassDef
ctx: ClassDefContext
def __call__(self, ctx: ClassDefContext) -> None: def __call__(self, ctx: ClassDefContext) -> None:
self.reason = ctx.reason self.reason = ctx.reason
self.class_defn = ctx.cls self.class_defn = ctx.cls
self.semanal_api = cast(SemanticAnalyzer, ctx.api) self.semanal_api = cast(SemanticAnalyzer, ctx.api)
self.ctx = ctx
self.modify_class_defn() self.modify_class_defn()
@abstractmethod @abstractmethod
@@ -214,6 +262,64 @@ class ClassDefPluginCallback(SemanalPluginCallback):
raise NotImplementedError raise NotImplementedError
class TypeCheckerPluginCallback(DjangoPluginCallback):
type_checker: TypeChecker
def get_current_module(self) -> MypyFile:
current_module = None
for item in reversed(self.type_checker.scope.stack):
if isinstance(item, MypyFile):
current_module = item
break
assert current_module is not None
return current_module
def qualified_name(self, name: str) -> str:
return self.type_checker.scope.stack[-1].fullname + '.' + name
def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]:
sym = self.plugin.lookup_fully_qualified(fullname)
if sym is None or sym.node is None:
return None
if not isinstance(sym.node, TypeInfo):
raise ValueError(f'{fullname!r} does not correspond to TypeInfo')
return sym.node
class GetMethodPluginCallback(TypeCheckerPluginCallback):
callee_type: Instance
ctx: MethodContext
def __call__(self, ctx: MethodContext) -> MypyType:
self.type_checker = ctx.api
assert isinstance(ctx.type, CallableType)
self.callee_type = ctx.type.ret_type
self.ctx = ctx
return self.get_method_return_type()
@abstractmethod
def get_method_return_type(self) -> MypyType:
raise NotImplementedError
class GetAttributeCallback(TypeCheckerPluginCallback):
obj_type: ProperType
default_attr_type: MypyType
error_context: MemberExpr
name: str
def __call__(self, ctx: AttributeContext) -> MypyType:
self.ctx = ctx
self.type_checker = ctx.api
self.obj_type = ctx.type
self.default_attr_type = ctx.default_attr_type
self.error_context = ctx.context
assert isinstance(self.error_context, MemberExpr)
self.name = self.error_context.name
return self.default_attr_type
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {}) return model_info.metadata.setdefault('django', {})

View File

@@ -17,11 +17,10 @@ from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import ( from mypy_django_plugin.transformers import (
fields, forms, init_create, meta, querysets, request, settings, fields, forms, init_create, meta, querysets, request, settings,
) )
from mypy_django_plugin.transformers.managers import (
create_manager_class_from_as_manager_method, instantiate_anonymous_queryset_from_as_manager)
from mypy_django_plugin.transformers.models import process_model_class from mypy_django_plugin.transformers.models import process_model_class
from mypy_django_plugin.transformers2.dynamic_managers import CreateNewManagerClassFrom_FromQuerySet from mypy_django_plugin.transformers2.dynamic_managers import CreateNewManagerClassFrom_FromQuerySet
from mypy_django_plugin.transformers2.models import ModelCallback from mypy_django_plugin.transformers2.models import ModelCallback
from mypy_django_plugin.transformers2.related_managers import GetRelatedManagerCallback
def transform_model_class(ctx: ClassDefContext, def transform_model_class(ctx: ClassDefContext,
@@ -176,10 +175,6 @@ class NewSemanalDjangoPlugin(Plugin):
if fullname == 'django.contrib.auth.get_user_model': if fullname == 'django.contrib.auth.get_user_model':
return partial(settings.get_user_model_hook, django_context=self.django_context) return partial(settings.get_user_model_hook, django_context=self.django_context)
# manager_bases = self._get_current_manager_bases()
# if fullname in manager_bases:
# return querysets.determine_proper_manager_type
info = self._get_typeinfo_or_none(fullname) info = self._get_typeinfo_or_none(fullname)
if info: if info:
if info.has_base(fullnames.FIELD_FULLNAME): if info.has_base(fullnames.FIELD_FULLNAME):
@@ -217,11 +212,6 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
if method_name == 'as_manager':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return instantiate_anonymous_queryset_from_as_manager
manager_classes = self._get_current_manager_bases() manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create': if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
@@ -253,6 +243,10 @@ class NewSemanalDjangoPlugin(Plugin):
info = self._get_typeinfo_or_none(class_name) info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user': if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user':
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)
if info and info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return GetRelatedManagerCallback(self)
return None return None
def get_dynamic_class_hook(self, fullname: str def get_dynamic_class_hook(self, fullname: str
@@ -263,12 +257,6 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return CreateNewManagerClassFrom_FromQuerySet(self) return CreateNewManagerClassFrom_FromQuerySet(self)
if fullname.endswith('as_manager'):
class_name, _, _ = fullname.rpartition('.')
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return create_manager_class_from_as_manager_method
return None return None

View File

@@ -230,102 +230,3 @@ def add_symbol_table_node(api: SemanticAnalyzer,
return True return True
return False return False
class CreateNewManagerClassFrom_AsManager(helpers.DynamicClassPluginCallback):
def create_new_dynamic_class(self) -> None:
pass
def create_manager_class_from_as_manager_method(ctx: DynamicClassDefContext) -> None:
semanal_api = sem_helpers.get_semanal_api(ctx)
try:
queryset_info = resolve_callee_info_or_exception(ctx)
django_manager_info = resolve_django_manager_info_or_exception(ctx)
except sem_helpers.IncompleteDefnError:
if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
generic_param: MypyType = AnyType(TypeOfAny.explicit)
generic_param_name = 'Any'
if (semanal_api.scope.classes
and semanal_api.scope.classes[-1].has_base(fullnames.MODEL_CLASS_FULLNAME)):
info = semanal_api.scope.classes[-1] # type: TypeInfo
generic_param = Instance(info, [])
generic_param_name = info.name
new_manager_class_name = queryset_info.name + '_AsManager_' + generic_param_name
new_manager_info = helpers.new_typeinfo(new_manager_class_name,
bases=[Instance(django_manager_info, [generic_param])],
module_name=semanal_api.cur_mod_id)
new_manager_info.set_line(ctx.call)
record_new_manager_info_fullname_into_metadata(ctx,
new_manager_info.fullname,
django_manager_info,
queryset_info,
django_manager_info)
class_def_context = ClassDefContext(cls=new_manager_info.defn,
reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, [AnyType(TypeOfAny.explicit)])
try:
for name, method_node in iter_all_custom_queryset_methods(queryset_info):
sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context,
self_type,
new_method_name=name,
method_node=method_node)
except sem_helpers.IncompleteDefnError:
if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
# context=None - forcibly replace old node
added = add_symbol_table_node(semanal_api, new_manager_class_name, new_manager_sym,
context=None,
symbol_table=semanal_api.globals)
if added:
# replace all references to the old manager Var everywhere
for _, module in semanal_api.modules.items():
if module.fullname != semanal_api.cur_mod_id:
for sym_name, sym in module.names.items():
if sym.fullname == new_manager_info.fullname:
module.names[sym_name] = new_manager_sym.copy()
# we need another iteration to process methods
if (not added
and not semanal_api.final_iteration):
semanal_api.defer()
def instantiate_anonymous_queryset_from_as_manager(ctx: MethodContext) -> MypyType:
api = chk_helpers.get_typechecker_api(ctx)
django_manager_info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME)
assert django_manager_info is not None
assert isinstance(ctx.type, CallableType)
assert isinstance(ctx.type.ret_type, Instance)
queryset_info = ctx.type.ret_type.type
gen_name = django_manager_info.name + 'From' + queryset_info.name
gen_fullname = 'django.db.models.manager' + '.' + gen_name
metadata = get_generated_managers_metadata(django_manager_info)
if gen_fullname not in metadata:
raise ValueError(f'{gen_fullname!r} is not present in generated managers list')
module_name, _, class_name = metadata[gen_fullname].rpartition('.')
current_module = helpers.get_current_module(api)
assert module_name == current_module.fullname
generated_manager_info = current_module.names[class_name].node
assert isinstance(generated_manager_info, TypeInfo)
return Instance(generated_manager_info, [])

View File

@@ -4,11 +4,17 @@ from mypy.checker import gen_unique_name
from mypy.nodes import NameExpr, TypeInfo, SymbolTableNode, StrExpr from mypy.nodes import NameExpr, TypeInfo, SymbolTableNode, StrExpr
from mypy.types import Type as MypyType, TypeVarType, TypeVarDef, Instance from mypy.types import Type as MypyType, TypeVarType, TypeVarDef, Instance
from mypy_django_plugin.lib import helpers, fullnames, chk_helpers, sem_helpers from mypy_django_plugin.lib import helpers, fullnames
from mypy_django_plugin.transformers.managers import iter_all_custom_queryset_methods from mypy_django_plugin.transformers.managers import iter_all_custom_queryset_methods
class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback): class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback):
def set_manager_mapping(self, runtime_manager_fullname: str, generated_manager_fullname: str) -> None:
base_model_info = self.lookup_typeinfo_or_defer(fullnames.MODEL_CLASS_FULLNAME)
assert base_model_info is not None
managers_metadata = base_model_info.metadata.setdefault('managers', {})
managers_metadata[runtime_manager_fullname] = generated_manager_fullname
def create_typevar_in_current_module(self, name: str, def create_typevar_in_current_module(self, name: str,
upper_bound: Optional[MypyType] = None) -> TypeVarDef: upper_bound: Optional[MypyType] = None) -> TypeVarDef:
tvar_name = gen_unique_name(name, self.semanal_api.globals) tvar_name = gen_unique_name(name, self.semanal_api.globals)
@@ -48,19 +54,20 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback)
parent_manager_type = Instance(callee_manager_info, [model_tvar_type]) parent_manager_type = Instance(callee_manager_info, [model_tvar_type])
# instantiate with a proper model, Manager[MyModel], filling all Manager type vars in process # instantiate with a proper model, Manager[MyModel], filling all Manager type vars in process
queryset_type = Instance(passed_queryset_info, [Instance(base_model_info, [])])
new_manager_info = self.new_typeinfo(self.class_name, new_manager_info = self.new_typeinfo(self.class_name,
bases=[parent_manager_type]) bases=[queryset_type, parent_manager_type])
new_manager_info.defn.type_vars = [model_tvar_defn] new_manager_info.defn.type_vars = [model_tvar_defn]
new_manager_info.type_vars = [model_tvar_defn.name] new_manager_info.type_vars = [model_tvar_defn.name]
new_manager_info.set_line(self.call_expr) new_manager_info.set_line(self.call_expr)
# copy methods from passed_queryset_info with self type replaced # copy methods from passed_queryset_info with self type replaced
self_type = Instance(new_manager_info, [model_tvar_type]) # self_type = Instance(new_manager_info, [model_tvar_type])
for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info): # for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info):
self.add_method_from_signature(method_node, # self.add_method_from_signature(method_node,
name, # name,
self_type, # self_type,
new_manager_info.defn) # new_manager_info.defn)
new_manager_sym = SymbolTableNode(self.semanal_api.current_symbol_kind(), new_manager_sym = SymbolTableNode(self.semanal_api.current_symbol_kind(),
new_manager_info, new_manager_info,
@@ -75,5 +82,5 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback)
runtime_manager_class_name = class_name_arg.value runtime_manager_class_name = class_name_arg.value
new_manager_name = runtime_manager_class_name or (callee_manager_info.name + 'From' + queryset_class_name) new_manager_name = runtime_manager_class_name or (callee_manager_info.name + 'From' + queryset_class_name)
django_generated_manager_name = 'django.db.models.manager.' + new_manager_name self.set_manager_mapping(f'django.db.models.manager.{new_manager_name}',
base_model_info.metadata.setdefault('managers', {})[django_generated_manager_name] = new_manager_info.fullname new_manager_info.fullname)

View File

@@ -3,16 +3,16 @@ from typing import Type, Optional
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields.related import OneToOneField, ForeignKey from django.db.models.fields.related import OneToOneField, ForeignKey
from django.db.models.fields.reverse_related import OneToOneRel, ManyToManyRel, ManyToOneRel from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF, Argument, ARG_STAR2
from mypy.checker import gen_unique_name
from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins import common
from mypy.semanal import dummy_context from mypy.semanal import dummy_context
from mypy.types import Instance, TypeOfAny, AnyType from mypy.types import Instance, TypeOfAny, AnyType
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from django.db import models from django.db import models
from mypy_django_plugin.lib import helpers, fullnames from django.db.models.fields import DateField, DateTimeField
from mypy_django_plugin.lib import helpers, fullnames, sem_helpers
from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_type from mypy_django_plugin.transformers.fields import get_field_type
from mypy_django_plugin.transformers2 import new_helpers from mypy_django_plugin.transformers2 import new_helpers
@@ -116,76 +116,77 @@ class AddPrimaryKeyIfDoesNotExist(TransformModelClassCallback):
class AddRelatedManagersCallback(TransformModelClassCallback): class AddRelatedManagersCallback(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
for relation in self.django_context.get_model_relations(runtime_model_cls): for reverse_manager_name, relation in self.django_context.get_model_relations(runtime_model_cls):
reverse_manager_name = relation.get_accessor_name()
if (reverse_manager_name is None if (reverse_manager_name is None
or reverse_manager_name in self.class_defn.info.names): or reverse_manager_name in self.class_defn.info.names):
continue continue
related_model_cls = self.django_context.get_field_related_model_cls(relation) self.add_new_model_attribute(reverse_manager_name, AnyType(TypeOfAny.implementation_artifact))
if related_model_cls is None: #
# could not find a referenced model (maybe invalid to= value, or GenericForeignKey) # related_model_cls = self.django_context.get_field_related_model_cls(relation)
continue # if related_model_cls is None:
# # could not find a referenced model (maybe invalid to= value, or GenericForeignKey)
related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls) # continue
if related_model_info is None: #
continue # related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls)
# if related_model_info is None:
if isinstance(relation, OneToOneRel): # continue
self.add_new_model_attribute(reverse_manager_name, #
Instance(related_model_info, [])) # if isinstance(relation, OneToOneRel):
elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): # self.add_new_model_attribute(reverse_manager_name,
related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS) # Instance(related_model_info, []))
if related_manager_info is None: # elif isinstance(relation, (ManyToOneRel, ManyToManyRel)):
if not self.defer_till_next_iteration(self.class_defn, # related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS)
reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'): # if related_manager_info is None:
raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS) # if not self.defer_till_next_iteration(self.class_defn,
continue # reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'):
# raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS)
# get type of default_manager for model # continue
default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) #
reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} ' # # get type of default_manager for model
f'of model {helpers.get_class_fullname(related_model_cls)!r}') # default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__)
default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname, # reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} '
reason_for_defer=reason_for_defer) # f'of model {helpers.get_class_fullname(related_model_cls)!r}')
if default_manager_info is None: # default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname,
continue # reason_for_defer=reason_for_defer)
# if default_manager_info is None:
default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) # continue
#
# related_model_cls._meta.default_manager.__class__ # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
# # we're making a subclass of 'objects', need to have it defined #
# if 'objects' not in related_model_info.names: # # related_model_cls._meta.default_manager.__class__
# if not self.defer_till_next_iteration(self.class_defn, # # # we're making a subclass of 'objects', need to have it defined
# reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"): # # if 'objects' not in related_model_info.names:
# raise AttributeNotFound(self.class_defn.info, 'objects') # # if not self.defer_till_next_iteration(self.class_defn,
# continue # # reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"):
# # raise AttributeNotFound(self.class_defn.info, 'objects')
related_manager_type = Instance(related_manager_info, # # continue
[Instance(related_model_info, [])]) #
# # related_manager_type = Instance(related_manager_info,
# objects_sym = related_model_info.names['objects'] # [Instance(related_model_info, [])])
# default_manager_type = objects_sym.type # #
# if default_manager_type is None: # # objects_sym = related_model_info.names['objects']
# # dynamic base class, extract from django_context # # default_manager_type = objects_sym.type
# default_manager_cls = related_model_cls._meta.default_manager.__class__ # # if default_manager_type is None:
# default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls) # # # dynamic base class, extract from django_context
# if default_manager_info is None: # # default_manager_cls = related_model_cls._meta.default_manager.__class__
# continue # # default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls)
# default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) # # if default_manager_info is None:
# # continue
if (not isinstance(default_manager_type, Instance) # # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): #
# if not defined or trivial -> just return RelatedManager[Model] # if (not isinstance(default_manager_type, Instance)
self.add_new_model_attribute(reverse_manager_name, related_manager_type) # or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
continue # # if not defined or trivial -> just return RelatedManager[Model]
# self.add_new_model_attribute(reverse_manager_name, related_manager_type)
# make anonymous class # continue
name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', #
self.semanal_api.current_symbol_table()) # # make anonymous class
bases = [related_manager_type, default_manager_type] # name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager',
new_manager_info = self.new_typeinfo(name, bases) # self.semanal_api.current_symbol_table())
self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, [])) # bases = [related_manager_type, default_manager_type]
# new_manager_info = self.new_typeinfo(name, bases)
# self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, []))
class AddForeignPrimaryKeys(TransformModelClassCallback): class AddForeignPrimaryKeys(TransformModelClassCallback):
@@ -222,6 +223,69 @@ class AddForeignPrimaryKeys(TransformModelClassCallback):
self.add_new_model_attribute(rel_pk_field_name, field_type) self.add_new_model_attribute(rel_pk_field_name, field_type)
class InjectAnyAsBaseForNestedMeta(TransformModelClassCallback):
"""
Replaces
class MyModel(models.Model):
class Meta:
pass
with
class MyModel(models.Model):
class Meta(Any):
pass
to get around incompatible Meta inner classes for different models.
"""
def modify_class_defn(self) -> None:
meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.class_defn.info)
if meta_node is None:
return None
meta_node.fallback_to_any = True
class AddMetaOptionsAttribute(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
if '_meta' not in self.class_defn.info.names:
options_info = self.lookup_typeinfo_or_defer(fullnames.OPTIONS_CLASS_FULLNAME)
if options_info is not None:
self.add_new_model_attribute('_meta',
Instance(options_info, [
Instance(self.class_defn.info, [])
]))
class AddExtraFieldMethods(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
# get_FOO_display for choices
for field in self.django_context.get_model_fields(runtime_model_cls):
if field.choices:
info = self.lookup_typeinfo_or_defer('builtins.str')
return_type = Instance(info, [])
common.add_method(self.ctx,
name='get_{}_display'.format(field.attname),
args=[],
return_type=return_type)
# get_next_by, get_previous_by for Date, DateTime
for field in self.django_context.get_model_fields(runtime_model_cls):
if isinstance(field, (DateField, DateTimeField)) and not field.null:
return_type = Instance(self.class_defn.info, [])
common.add_method(self.ctx,
name='get_next_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
common.add_method(self.ctx,
name='get_previous_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
class ModelCallback(helpers.ClassDefPluginCallback): class ModelCallback(helpers.ClassDefPluginCallback):
def __call__(self, ctx: ClassDefContext) -> None: def __call__(self, ctx: ClassDefContext) -> None:
callback_classes = [ callback_classes = [
@@ -230,6 +294,9 @@ class ModelCallback(helpers.ClassDefPluginCallback):
AddForeignPrimaryKeys, AddForeignPrimaryKeys,
AddDefaultManagerCallback, AddDefaultManagerCallback,
AddRelatedManagersCallback, AddRelatedManagersCallback,
InjectAnyAsBaseForNestedMeta,
AddMetaOptionsAttribute,
AddExtraFieldMethods,
] ]
for callback_cls in callback_classes: for callback_cls in callback_classes:
callback = callback_cls(self.plugin) callback = callback_cls(self.plugin)

View File

@@ -22,5 +22,9 @@ class NameNotFound(IncompleteDefnError):
super().__init__(f'Could not find {name!r} in the current activated namespaces') super().__init__(f'Could not find {name!r} in the current activated namespaces')
class SymbolAdditionNotPossible(Exception):
pass
def get_class_fullname(klass: type) -> str: def get_class_fullname(klass: type) -> str:
return klass.__module__ + '.' + klass.__qualname__ return klass.__module__ + '.' + klass.__qualname__

View File

@@ -0,0 +1,69 @@
from mypy.checker import gen_unique_name
from mypy.plugin import AttributeContext
from mypy.types import Instance
from mypy.types import Type as MypyType
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel, ManyToOneRel, ManyToManyRel
from mypy_django_plugin.lib import helpers, fullnames
from mypy_django_plugin.lib.helpers import GetAttributeCallback
class GetRelatedManagerCallback(GetAttributeCallback):
obj_type: Instance
def get_related_manager_type(self, relation: ForeignObjectRel) -> MypyType:
related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None:
# could not find a referenced model (maybe invalid to= value, or GenericForeignKey)
# TODO: show error
return self.default_attr_type
related_model_info = self.lookup_typeinfo(helpers.get_class_fullname(related_model_cls))
if related_model_info is None:
# TODO: show error
return self.default_attr_type
if isinstance(relation, OneToOneRel):
return Instance(related_model_info, [])
elif isinstance(relation, (ManyToOneRel, ManyToManyRel)):
related_manager_info = self.lookup_typeinfo(fullnames.RELATED_MANAGER_CLASS)
if related_manager_info is None:
return self.default_attr_type
# get type of default_manager for model
default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__)
default_manager_info = self.lookup_typeinfo(default_manager_fullname)
if default_manager_info is None:
return self.default_attr_type
default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
related_manager_type = Instance(related_manager_info,
[Instance(related_model_info, [])])
if (not isinstance(default_manager_type, Instance)
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
# if not defined or trivial -> just return RelatedManager[Model]
return related_manager_type
# make anonymous class
name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager',
self.obj_type.type.names)
bases = [related_manager_type, default_manager_type]
new_manager_info = self.new_typeinfo(name, bases)
return Instance(new_manager_info, [])
def __call__(self, ctx: AttributeContext):
super().__call__(ctx)
assert isinstance(self.obj_type, Instance)
model_fullname = self.obj_type.type.fullname
model_cls = self.django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return self.default_attr_type
for reverse_manager_name, relation in self.django_context.get_model_relations(model_cls):
if reverse_manager_name == self.name:
return self.get_related_manager_type(relation)
return self.default_attr_type

View File

@@ -653,7 +653,7 @@
- case: related_manager_is_a_subclass_of_default_manager - case: related_manager_is_a_subclass_of_default_manager
main: | main: |
from myapp.models import User from myapp.models import User
reveal_type(User().orders) # N: Revealed type is 'myapp.models.User.Order_RelatedManager' reveal_type(User().orders) # N: Revealed type is 'main.Order_RelatedManager'
reveal_type(User().orders.get()) # N: Revealed type is 'myapp.models.Order*' reveal_type(User().orders.get()) # N: Revealed type is 'myapp.models.Order*'
reveal_type(User().orders.manager_method()) # N: Revealed type is 'builtins.int' reveal_type(User().orders.manager_method()) # N: Revealed type is 'builtins.int'
installed_apps: installed_apps:

View File

@@ -1,95 +0,0 @@
- case: anonymous_queryset_from_as_manager_inside_model
main: |
from myapp.models import MyModel
reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_MyModel'
reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
reveal_type(MyModel.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(MyModel.objects.queryset_method()) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
class MyModel(models.Model):
objects = MyQuerySet.as_manager()
- case: two_invocations_parametrized_with_different_models
main: |
from myapp.models import User, Blog
reveal_type(User.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_User'
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(User.objects.queryset_method()) # N: Revealed type is 'builtins.int'
reveal_type(Blog.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Blog'
reveal_type(Blog.objects.get()) # N: Revealed type is 'myapp.models.Blog*'
reveal_type(Blog.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(Blog.objects.queryset_method()) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
class User(models.Model):
objects = MyQuerySet.as_manager()
class Blog(models.Model):
objects = MyQuerySet.as_manager()
- case: as_manager_outside_model_parametrized_with_any
main: |
from myapp.models import NotModel, outside_objects
reveal_type(NotModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any'
reveal_type(NotModel.objects.get()) # N: Revealed type is 'Any'
reveal_type(outside_objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any'
reveal_type(outside_objects.get()) # N: Revealed type is 'Any'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
outside_objects = MyQuerySet.as_manager()
class NotModel:
objects = MyQuerySet.as_manager()
- case: test_as_manager_without_name_to_bind_in_different_files
main: |
from myapp.models import MyQuerySet
reveal_type(MyQuerySet.as_manager()) # N: Revealed type is 'Any'
reveal_type(MyQuerySet.as_manager().get()) # N: Revealed type is 'Any'
reveal_type(MyQuerySet.as_manager().mymethod()) # N: Revealed type is 'Any'
from myapp import helpers
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def mymethod(self) -> int:
pass
class MyModel(models.Model):
objects = MyQuerySet.as_manager()
- path: myapp/helpers.py
content: |
from myapp.models import MyQuerySet
MyQuerySet.as_manager()

View File

@@ -17,11 +17,15 @@
- path: myapp/__init__.py - path: myapp/__init__.py
- path: myapp/models.py - path: myapp/models.py
content: | content: |
from typing import TypeVar
from django.db import models from django.db import models
from django.db.models.manager import BaseManager, Manager from django.db.models.manager import BaseManager, Manager
from mypy_django_plugin.lib import generics from mypy_django_plugin.lib import generics
class ModelQuerySet(models.QuerySet): generics.make_classes_generic(models.QuerySet)
_M = TypeVar('_M', bound=models.Model)
class ModelQuerySet(models.QuerySet[_M]):
def queryset_method(self) -> str: def queryset_method(self) -> str:
return 'hello' return 'hello'

View File

@@ -0,0 +1,3 @@
digraph {
FuncDef [label="My FuncDef"]
}

Binary file not shown.