This commit is contained in:
Maxim Kurnikov
2020-03-15 00:58:11 +03:00
parent 0b1507c81e
commit 1419b144d9
20 changed files with 513 additions and 321 deletions

View File

@@ -42,6 +42,9 @@ _ST = TypeVar("_ST")
# __get__ return type
_GT = TypeVar("_GT")
class CharField(Field[str, str]):
class Field(RegisterLookupMixin, Generic[_ST, _GT]):
_pyi_private_set_type: Any
_pyi_private_get_type: Any

113
mfile.py Normal file
View File

@@ -0,0 +1,113 @@
from graphviz import Digraph
from mypy.options import Options
source = """
from root.package import MyQuerySet
MyQuerySet().mymethod()
"""
from mypy import parse
parsed = parse.parse(source, 'myfile.py', None, None, Options())
print(parsed)
graphattrs = {
"labelloc": "t",
"fontcolor": "blue",
# "bgcolor": "#333333",
"margin": "0",
}
nodeattrs = {
# "color": "white",
"fontcolor": "#00008b",
# "style": "filled",
# "fillcolor": "#ffffff",
# "fillcolor": "#006699",
}
edgeattrs = {
# "color": "white",
# "fontcolor": "white",
}
graph = Digraph('mfile.py', graph_attr=graphattrs, node_attr=nodeattrs, edge_attr=edgeattrs)
graph.node('__builtins__')
graph.node('django.db.models')
graph.node('django.db.models.fields')
graph.edge('django.db.models', 'django.db.models.fields')
graph.edge('django.db.models', '__builtins__')
graph.edge('django.db.models.fields', '__builtins__')
graph.node('mymodule')
graph.edge('mymodule', 'django.db.models')
graph.edge('mymodule', '__builtins__')
#
# graph.node('ImportFrom', label='ImportFrom(val=root.package, [MyQuerySet])')
# graph.edge('MypyFile', 'ImportFrom')
# graph.node('ClassDef_MyQuerySet', label='ClassDef(name=MyQuerySet)')
# graph.edge('MypyFile', 'ClassDef_MyQuerySet')
#
# graph.node('FuncDef_mymethod', label='FuncDef(name=mymethod)')
# graph.edge('ClassDef_MyQuerySet', 'FuncDef_mymethod')
#
# graph.node('Args', label='Args')
# graph.edge('FuncDef_mymethod', 'Args')
#
# graph.node('Var_self', label='Var(name=self)')
# graph.edge('Args', 'Var_self')
#
# graph.node('Block', label='Block')
# graph.edge('FuncDef_mymethod', 'Block')
#
# graph.node('PassStmt')
# graph.edge('Block', 'PassStmt')
# graph.node('ExpressionStmt')
# graph.edge('MypyFile', 'ExpressionStmt')
#
# graph.node('CallExpr', label='CallExpr(val="MyQuerySet()")')
# graph.edge('ExpressionStmt', 'CallExpr')
#
# graph.node('MemberExpr', label='MemberExpr(val=".mymethod()")')
# graph.edge('CallExpr', 'MemberExpr')
#
# graph.node('CallExpr_outer_Args', label='Args()')
# graph.edge('CallExpr', 'CallExpr_outer_Args')
#
# graph.node('CallExpr_inner', label='CallExpr(val="mymethod()")')
# graph.edge('MemberExpr', 'CallExpr_inner')
#
# graph.node('NameExpr', label='NameExpr(val="mymethod")')
# graph.edge('CallExpr_inner', 'NameExpr')
#
# graph.node('Expression_Args', label='Args()')
# graph.edge('CallExpr_inner', 'Expression_Args')
graph.render(view=True, format='png')
# MypyFile(
# ClassDef(
# name=MyQuerySet,
# FuncDef(
# name=mymethod,
# Args(
# Var(self))
# Block(PassStmt())
# )
# )
# ExpressionStmt:6(
# CallExpr:6(
# MemberExpr:6(
# CallExpr:6(
# NameExpr(MyQuerySet)
# Args())
# mymethod)
# Args())))

13
mfile.py.gv Normal file
View File

@@ -0,0 +1,13 @@
digraph "mfile.py" {
graph [fontcolor=blue labelloc=t margin=0]
node [fontcolor="#00008b"]
__builtins__
"django.db.models"
"django.db.models.fields"
"django.db.models" -> "django.db.models.fields"
"django.db.models" -> __builtins__
"django.db.models.fields" -> __builtins__
mymodule
mymodule -> "django.db.models"
mymodule -> __builtins__
}

BIN
mfile.py.gv.pdf Normal file

Binary file not shown.

BIN
mfile.py.gv.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

9
my.gv Normal file
View File

@@ -0,0 +1,9 @@
digraph AST {
File
ClassDef
ClassDef -> File
FuncDef
FuncDef -> ClassDef
ExpressionStmt
ExpressionStmt -> File
}

BIN
my.gv.pdf Normal file

Binary file not shown.

View File

@@ -11,7 +11,7 @@ from django.db import models
from django.db.models.base import Model
from django.db.models.fields import AutoField, CharField, Field
from django.db.models.fields.related import ForeignKey, RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.lookups import Exact
from django.db.models.sql.query import Query
from django.utils.functional import cached_property
@@ -119,10 +119,10 @@ class DjangoContext:
if isinstance(field, Field):
yield field
def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel):
yield field
def get_model_relations(self, model_cls: Type[Model]) -> Iterator[Tuple[Optional[str], ForeignObjectRel]]:
for relation in model_cls._meta.get_fields():
if isinstance(relation, ForeignObjectRel):
yield relation.get_accessor_name(), relation
def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):

View File

@@ -10,11 +10,11 @@ from mypy.mro import calculate_mro
from mypy.nodes import (
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode,
TypeInfo, Var,
CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo)
from mypy.plugin import DynamicClassDefContext, ClassDefContext
CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo, OverloadedFuncDef, Decorator)
from mypy.plugin import DynamicClassDefContext, ClassDefContext, AttributeContext, MethodContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance, NoneTyp, TypeType
from mypy.semanal import SemanticAnalyzer, is_valid_replacement, is_same_symbol
from mypy.types import AnyType, Instance, NoneTyp, TypeType, ProperType, CallableType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny, UnionType
from mypy.typetraverser import TypeTraverserVisitor
@@ -38,8 +38,25 @@ class DjangoPluginCallback:
self.plugin = plugin
self.django_context = plugin.django_context
# def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
# return self.plugin.lookup_fully_qualified(fullname)
def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo:
class_def = ClassDef(name, Block([]))
class_def.fullname = self.qualified_name(name)
info = TypeInfo(SymbolTable(), class_def, self.get_current_module().fullname)
info.bases = bases
calculate_mro(info)
info.metaclass_type = info.calculate_metaclass_type()
class_def.info = info
return info
@abstractmethod
def get_current_module(self) -> MypyFile:
raise NotImplementedError()
@abstractmethod
def qualified_name(self, name: str) -> str:
raise NotImplementedError()
class SemanalPluginCallback(DjangoPluginCallback):
@@ -58,6 +75,12 @@ class SemanalPluginCallback(DjangoPluginCallback):
print(f'LOG: defer: {self.build_defer_error_message(reason)}')
return True
def get_current_module(self) -> MypyFile:
return self.semanal_api.cur_mod_node
def qualified_name(self, name: str) -> str:
return self.semanal_api.qualified_name(name)
def lookup_typeinfo_or_defer(self, fullname: str, *,
deferral_context: Optional[Context] = None,
reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]:
@@ -74,11 +97,12 @@ class SemanalPluginCallback(DjangoPluginCallback):
return sym.node
def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo:
def new_typeinfo(self, name: str, bases: List[Instance], module_fullname: Optional[str] = None) -> TypeInfo:
class_def = ClassDef(name, Block([]))
class_def.fullname = self.semanal_api.qualified_name(name)
info = TypeInfo(SymbolTable(), class_def, self.semanal_api.cur_mod_id)
info = TypeInfo(SymbolTable(), class_def,
module_fullname or self.get_current_module().fullname)
info.bases = bases
calculate_mro(info)
info.metaclass_type = info.calculate_metaclass_type()
@@ -86,6 +110,43 @@ class SemanalPluginCallback(DjangoPluginCallback):
class_def.info = info
return info
def add_symbol_table_node(self,
name: str,
symbol: SymbolTableNode,
symbol_table: Optional[SymbolTable] = None,
context: Optional[Context] = None,
can_defer: bool = True,
escape_comprehensions: bool = False) -> None:
""" Patched copy of SemanticAnalyzer.add_symbol_table_node(). """
names = symbol_table or self.semanal_api.current_symbol_table(escape_comprehensions=escape_comprehensions)
existing = names.get(name)
if isinstance(symbol.node, PlaceholderNode) and can_defer:
self.semanal_api.defer(context)
return None
if (existing is not None
and context is not None
and not is_valid_replacement(existing, symbol)):
# There is an existing node, so this may be a redefinition.
# If the new node points to the same node as the old one,
# or if both old and new nodes are placeholders, we don't
# need to do anything.
old = existing.node
new = symbol.node
if isinstance(new, PlaceholderNode):
# We don't know whether this is okay. Let's wait until the next iteration.
return False
if not is_same_symbol(old, new):
if isinstance(new, (FuncDef, Decorator, OverloadedFuncDef, TypeInfo)):
self.semanal_api.add_redefinition(names, name, symbol)
if not (isinstance(new, (FuncDef, Decorator))
and self.semanal_api.set_original_def(old, new)):
self.semanal_api.name_already_defined(name, context, existing)
elif name not in self.semanal_api.missing_names and '*' not in self.semanal_api.missing_names:
names[name] = symbol
self.progress = True
return None
raise new_helpers.SymbolAdditionNotPossible()
# def add_symbol_table_node_or_defer(self, name: str, sym: SymbolTableNode) -> bool:
# return self.semanal_api.add_symbol_table_node(name, sym,
# context=self.semanal_api.cur_mod_node)
@@ -119,20 +180,6 @@ class SemanalPluginCallback(DjangoPluginCallback):
self.semanal_api.add_imported_symbol(name, sym, context=self.semanal_api.cur_mod_node)
class UnimportedTypesVisitor(TypeTraverserVisitor):
def visit_union_type(self, t: UnionType) -> None:
super().visit_union_type(t)
union_sym = currently_imported_symbols.get('Union')
if union_sym is None:
# TODO: check if it's exactly typing.Union
import_symbol_from_source('Union')
def visit_type_type(self, t: TypeType) -> None:
super().visit_type_type(t)
type_sym = currently_imported_symbols.get('Union')
if type_sym is None:
# TODO: check if it's exactly typing.Type
import_symbol_from_source('Type')
def visit_instance(self, t: Instance) -> None:
super().visit_instance(t)
if isinstance(t.type, FakeInfo):
@@ -140,7 +187,6 @@ class SemanalPluginCallback(DjangoPluginCallback):
type_name = t.type.name
sym = currently_imported_symbols.get(type_name)
if sym is None:
# TODO: check if it's exactly typing.Type
import_symbol_from_source(type_name)
signature_node.type.accept(UnimportedTypesVisitor())
@@ -202,11 +248,13 @@ class DynamicClassPluginCallback(SemanalPluginCallback):
class ClassDefPluginCallback(SemanalPluginCallback):
reason: Expression
class_defn: ClassDef
ctx: ClassDefContext
def __call__(self, ctx: ClassDefContext) -> None:
self.reason = ctx.reason
self.class_defn = ctx.cls
self.semanal_api = cast(SemanticAnalyzer, ctx.api)
self.ctx = ctx
self.modify_class_defn()
@abstractmethod
@@ -214,6 +262,64 @@ class ClassDefPluginCallback(SemanalPluginCallback):
raise NotImplementedError
class TypeCheckerPluginCallback(DjangoPluginCallback):
type_checker: TypeChecker
def get_current_module(self) -> MypyFile:
current_module = None
for item in reversed(self.type_checker.scope.stack):
if isinstance(item, MypyFile):
current_module = item
break
assert current_module is not None
return current_module
def qualified_name(self, name: str) -> str:
return self.type_checker.scope.stack[-1].fullname + '.' + name
def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]:
sym = self.plugin.lookup_fully_qualified(fullname)
if sym is None or sym.node is None:
return None
if not isinstance(sym.node, TypeInfo):
raise ValueError(f'{fullname!r} does not correspond to TypeInfo')
return sym.node
class GetMethodPluginCallback(TypeCheckerPluginCallback):
callee_type: Instance
ctx: MethodContext
def __call__(self, ctx: MethodContext) -> MypyType:
self.type_checker = ctx.api
assert isinstance(ctx.type, CallableType)
self.callee_type = ctx.type.ret_type
self.ctx = ctx
return self.get_method_return_type()
@abstractmethod
def get_method_return_type(self) -> MypyType:
raise NotImplementedError
class GetAttributeCallback(TypeCheckerPluginCallback):
obj_type: ProperType
default_attr_type: MypyType
error_context: MemberExpr
name: str
def __call__(self, ctx: AttributeContext) -> MypyType:
self.ctx = ctx
self.type_checker = ctx.api
self.obj_type = ctx.type
self.default_attr_type = ctx.default_attr_type
self.error_context = ctx.context
assert isinstance(self.error_context, MemberExpr)
self.name = self.error_context.name
return self.default_attr_type
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})

View File

@@ -17,11 +17,10 @@ from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import (
fields, forms, init_create, meta, querysets, request, settings,
)
from mypy_django_plugin.transformers.managers import (
create_manager_class_from_as_manager_method, instantiate_anonymous_queryset_from_as_manager)
from mypy_django_plugin.transformers.models import process_model_class
from mypy_django_plugin.transformers2.dynamic_managers import CreateNewManagerClassFrom_FromQuerySet
from mypy_django_plugin.transformers2.models import ModelCallback
from mypy_django_plugin.transformers2.related_managers import GetRelatedManagerCallback
def transform_model_class(ctx: ClassDefContext,
@@ -176,10 +175,6 @@ class NewSemanalDjangoPlugin(Plugin):
if fullname == 'django.contrib.auth.get_user_model':
return partial(settings.get_user_model_hook, django_context=self.django_context)
# manager_bases = self._get_current_manager_bases()
# if fullname in manager_bases:
# return querysets.determine_proper_manager_type
info = self._get_typeinfo_or_none(fullname)
if info:
if info.has_base(fullnames.FIELD_FULLNAME):
@@ -217,11 +212,6 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
if method_name == 'as_manager':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return instantiate_anonymous_queryset_from_as_manager
manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
@@ -253,6 +243,10 @@ class NewSemanalDjangoPlugin(Plugin):
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user':
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)
if info and info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return GetRelatedManagerCallback(self)
return None
def get_dynamic_class_hook(self, fullname: str
@@ -263,12 +257,6 @@ class NewSemanalDjangoPlugin(Plugin):
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return CreateNewManagerClassFrom_FromQuerySet(self)
if fullname.endswith('as_manager'):
class_name, _, _ = fullname.rpartition('.')
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return create_manager_class_from_as_manager_method
return None

View File

@@ -230,102 +230,3 @@ def add_symbol_table_node(api: SemanticAnalyzer,
return True
return False
class CreateNewManagerClassFrom_AsManager(helpers.DynamicClassPluginCallback):
def create_new_dynamic_class(self) -> None:
pass
def create_manager_class_from_as_manager_method(ctx: DynamicClassDefContext) -> None:
semanal_api = sem_helpers.get_semanal_api(ctx)
try:
queryset_info = resolve_callee_info_or_exception(ctx)
django_manager_info = resolve_django_manager_info_or_exception(ctx)
except sem_helpers.IncompleteDefnError:
if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
generic_param: MypyType = AnyType(TypeOfAny.explicit)
generic_param_name = 'Any'
if (semanal_api.scope.classes
and semanal_api.scope.classes[-1].has_base(fullnames.MODEL_CLASS_FULLNAME)):
info = semanal_api.scope.classes[-1] # type: TypeInfo
generic_param = Instance(info, [])
generic_param_name = info.name
new_manager_class_name = queryset_info.name + '_AsManager_' + generic_param_name
new_manager_info = helpers.new_typeinfo(new_manager_class_name,
bases=[Instance(django_manager_info, [generic_param])],
module_name=semanal_api.cur_mod_id)
new_manager_info.set_line(ctx.call)
record_new_manager_info_fullname_into_metadata(ctx,
new_manager_info.fullname,
django_manager_info,
queryset_info,
django_manager_info)
class_def_context = ClassDefContext(cls=new_manager_info.defn,
reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, [AnyType(TypeOfAny.explicit)])
try:
for name, method_node in iter_all_custom_queryset_methods(queryset_info):
sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context,
self_type,
new_method_name=name,
method_node=method_node)
except sem_helpers.IncompleteDefnError:
if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
# context=None - forcibly replace old node
added = add_symbol_table_node(semanal_api, new_manager_class_name, new_manager_sym,
context=None,
symbol_table=semanal_api.globals)
if added:
# replace all references to the old manager Var everywhere
for _, module in semanal_api.modules.items():
if module.fullname != semanal_api.cur_mod_id:
for sym_name, sym in module.names.items():
if sym.fullname == new_manager_info.fullname:
module.names[sym_name] = new_manager_sym.copy()
# we need another iteration to process methods
if (not added
and not semanal_api.final_iteration):
semanal_api.defer()
def instantiate_anonymous_queryset_from_as_manager(ctx: MethodContext) -> MypyType:
api = chk_helpers.get_typechecker_api(ctx)
django_manager_info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME)
assert django_manager_info is not None
assert isinstance(ctx.type, CallableType)
assert isinstance(ctx.type.ret_type, Instance)
queryset_info = ctx.type.ret_type.type
gen_name = django_manager_info.name + 'From' + queryset_info.name
gen_fullname = 'django.db.models.manager' + '.' + gen_name
metadata = get_generated_managers_metadata(django_manager_info)
if gen_fullname not in metadata:
raise ValueError(f'{gen_fullname!r} is not present in generated managers list')
module_name, _, class_name = metadata[gen_fullname].rpartition('.')
current_module = helpers.get_current_module(api)
assert module_name == current_module.fullname
generated_manager_info = current_module.names[class_name].node
assert isinstance(generated_manager_info, TypeInfo)
return Instance(generated_manager_info, [])

View File

@@ -4,11 +4,17 @@ from mypy.checker import gen_unique_name
from mypy.nodes import NameExpr, TypeInfo, SymbolTableNode, StrExpr
from mypy.types import Type as MypyType, TypeVarType, TypeVarDef, Instance
from mypy_django_plugin.lib import helpers, fullnames, chk_helpers, sem_helpers
from mypy_django_plugin.lib import helpers, fullnames
from mypy_django_plugin.transformers.managers import iter_all_custom_queryset_methods
class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback):
def set_manager_mapping(self, runtime_manager_fullname: str, generated_manager_fullname: str) -> None:
base_model_info = self.lookup_typeinfo_or_defer(fullnames.MODEL_CLASS_FULLNAME)
assert base_model_info is not None
managers_metadata = base_model_info.metadata.setdefault('managers', {})
managers_metadata[runtime_manager_fullname] = generated_manager_fullname
def create_typevar_in_current_module(self, name: str,
upper_bound: Optional[MypyType] = None) -> TypeVarDef:
tvar_name = gen_unique_name(name, self.semanal_api.globals)
@@ -48,19 +54,20 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback)
parent_manager_type = Instance(callee_manager_info, [model_tvar_type])
# instantiate with a proper model, Manager[MyModel], filling all Manager type vars in process
queryset_type = Instance(passed_queryset_info, [Instance(base_model_info, [])])
new_manager_info = self.new_typeinfo(self.class_name,
bases=[parent_manager_type])
bases=[queryset_type, parent_manager_type])
new_manager_info.defn.type_vars = [model_tvar_defn]
new_manager_info.type_vars = [model_tvar_defn.name]
new_manager_info.set_line(self.call_expr)
# copy methods from passed_queryset_info with self type replaced
self_type = Instance(new_manager_info, [model_tvar_type])
for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info):
self.add_method_from_signature(method_node,
name,
self_type,
new_manager_info.defn)
# self_type = Instance(new_manager_info, [model_tvar_type])
# for name, method_node in iter_all_custom_queryset_methods(passed_queryset_info):
# self.add_method_from_signature(method_node,
# name,
# self_type,
# new_manager_info.defn)
new_manager_sym = SymbolTableNode(self.semanal_api.current_symbol_kind(),
new_manager_info,
@@ -75,5 +82,5 @@ class CreateNewManagerClassFrom_FromQuerySet(helpers.DynamicClassPluginCallback)
runtime_manager_class_name = class_name_arg.value
new_manager_name = runtime_manager_class_name or (callee_manager_info.name + 'From' + queryset_class_name)
django_generated_manager_name = 'django.db.models.manager.' + new_manager_name
base_model_info.metadata.setdefault('managers', {})[django_generated_manager_name] = new_manager_info.fullname
self.set_manager_mapping(f'django.db.models.manager.{new_manager_name}',
new_manager_info.fullname)

View File

@@ -3,16 +3,16 @@ from typing import Type, Optional
from django.db.models.base import Model
from django.db.models.fields.related import OneToOneField, ForeignKey
from django.db.models.fields.reverse_related import OneToOneRel, ManyToManyRel, ManyToOneRel
from mypy.checker import gen_unique_name
from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF
from mypy.nodes import TypeInfo, Var, SymbolTableNode, MDEF, Argument, ARG_STAR2
from mypy.plugin import ClassDefContext
from mypy.plugins import common
from mypy.semanal import dummy_context
from mypy.types import Instance, TypeOfAny, AnyType
from mypy.types import Type as MypyType
from django.db import models
from mypy_django_plugin.lib import helpers, fullnames
from django.db.models.fields import DateField, DateTimeField
from mypy_django_plugin.lib import helpers, fullnames, sem_helpers
from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_type
from mypy_django_plugin.transformers2 import new_helpers
@@ -116,76 +116,77 @@ class AddPrimaryKeyIfDoesNotExist(TransformModelClassCallback):
class AddRelatedManagersCallback(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
for relation in self.django_context.get_model_relations(runtime_model_cls):
reverse_manager_name = relation.get_accessor_name()
for reverse_manager_name, relation in self.django_context.get_model_relations(runtime_model_cls):
if (reverse_manager_name is None
or reverse_manager_name in self.class_defn.info.names):
continue
related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None:
# could not find a referenced model (maybe invalid to= value, or GenericForeignKey)
continue
related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls)
if related_model_info is None:
continue
if isinstance(relation, OneToOneRel):
self.add_new_model_attribute(reverse_manager_name,
Instance(related_model_info, []))
elif isinstance(relation, (ManyToOneRel, ManyToManyRel)):
related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS)
if related_manager_info is None:
if not self.defer_till_next_iteration(self.class_defn,
reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'):
raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS)
continue
# get type of default_manager for model
default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__)
reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} '
f'of model {helpers.get_class_fullname(related_model_cls)!r}')
default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname,
reason_for_defer=reason_for_defer)
if default_manager_info is None:
continue
default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
# related_model_cls._meta.default_manager.__class__
# # we're making a subclass of 'objects', need to have it defined
# if 'objects' not in related_model_info.names:
# if not self.defer_till_next_iteration(self.class_defn,
# reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"):
# raise AttributeNotFound(self.class_defn.info, 'objects')
# continue
related_manager_type = Instance(related_manager_info,
[Instance(related_model_info, [])])
#
# objects_sym = related_model_info.names['objects']
# default_manager_type = objects_sym.type
# if default_manager_type is None:
# # dynamic base class, extract from django_context
# default_manager_cls = related_model_cls._meta.default_manager.__class__
# default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls)
# if default_manager_info is None:
# continue
# default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
if (not isinstance(default_manager_type, Instance)
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
# if not defined or trivial -> just return RelatedManager[Model]
self.add_new_model_attribute(reverse_manager_name, related_manager_type)
continue
# make anonymous class
name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager',
self.semanal_api.current_symbol_table())
bases = [related_manager_type, default_manager_type]
new_manager_info = self.new_typeinfo(name, bases)
self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, []))
self.add_new_model_attribute(reverse_manager_name, AnyType(TypeOfAny.implementation_artifact))
#
# related_model_cls = self.django_context.get_field_related_model_cls(relation)
# if related_model_cls is None:
# # could not find a referenced model (maybe invalid to= value, or GenericForeignKey)
# continue
#
# related_model_info = self.lookup_typeinfo_for_class_or_defer(related_model_cls)
# if related_model_info is None:
# continue
#
# if isinstance(relation, OneToOneRel):
# self.add_new_model_attribute(reverse_manager_name,
# Instance(related_model_info, []))
# elif isinstance(relation, (ManyToOneRel, ManyToManyRel)):
# related_manager_info = self.lookup_typeinfo_or_defer(fullnames.RELATED_MANAGER_CLASS)
# if related_manager_info is None:
# if not self.defer_till_next_iteration(self.class_defn,
# reason=f'{fullnames.RELATED_MANAGER_CLASS!r} is not available for lookup'):
# raise TypeInfoNotFound(fullnames.RELATED_MANAGER_CLASS)
# continue
#
# # get type of default_manager for model
# default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__)
# reason_for_defer = (f'Trying to lookup default_manager {default_manager_fullname!r} '
# f'of model {helpers.get_class_fullname(related_model_cls)!r}')
# default_manager_info = self.lookup_typeinfo_or_defer(default_manager_fullname,
# reason_for_defer=reason_for_defer)
# if default_manager_info is None:
# continue
#
# default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
#
# # related_model_cls._meta.default_manager.__class__
# # # we're making a subclass of 'objects', need to have it defined
# # if 'objects' not in related_model_info.names:
# # if not self.defer_till_next_iteration(self.class_defn,
# # reason=f"'objects' manager is not yet defined on {related_model_info.fullname!r}"):
# # raise AttributeNotFound(self.class_defn.info, 'objects')
# # continue
#
# related_manager_type = Instance(related_manager_info,
# [Instance(related_model_info, [])])
# #
# # objects_sym = related_model_info.names['objects']
# # default_manager_type = objects_sym.type
# # if default_manager_type is None:
# # # dynamic base class, extract from django_context
# # default_manager_cls = related_model_cls._meta.default_manager.__class__
# # default_manager_info = self.lookup_typeinfo_for_class_or_defer(default_manager_cls)
# # if default_manager_info is None:
# # continue
# # default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
#
# if (not isinstance(default_manager_type, Instance)
# or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
# # if not defined or trivial -> just return RelatedManager[Model]
# self.add_new_model_attribute(reverse_manager_name, related_manager_type)
# continue
#
# # make anonymous class
# name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager',
# self.semanal_api.current_symbol_table())
# bases = [related_manager_type, default_manager_type]
# new_manager_info = self.new_typeinfo(name, bases)
# self.add_new_model_attribute(reverse_manager_name, Instance(new_manager_info, []))
class AddForeignPrimaryKeys(TransformModelClassCallback):
@@ -222,6 +223,69 @@ class AddForeignPrimaryKeys(TransformModelClassCallback):
self.add_new_model_attribute(rel_pk_field_name, field_type)
class InjectAnyAsBaseForNestedMeta(TransformModelClassCallback):
"""
Replaces
class MyModel(models.Model):
class Meta:
pass
with
class MyModel(models.Model):
class Meta(Any):
pass
to get around incompatible Meta inner classes for different models.
"""
def modify_class_defn(self) -> None:
meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.class_defn.info)
if meta_node is None:
return None
meta_node.fallback_to_any = True
class AddMetaOptionsAttribute(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
if '_meta' not in self.class_defn.info.names:
options_info = self.lookup_typeinfo_or_defer(fullnames.OPTIONS_CLASS_FULLNAME)
if options_info is not None:
self.add_new_model_attribute('_meta',
Instance(options_info, [
Instance(self.class_defn.info, [])
]))
class AddExtraFieldMethods(TransformModelClassCallback):
def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None:
# get_FOO_display for choices
for field in self.django_context.get_model_fields(runtime_model_cls):
if field.choices:
info = self.lookup_typeinfo_or_defer('builtins.str')
return_type = Instance(info, [])
common.add_method(self.ctx,
name='get_{}_display'.format(field.attname),
args=[],
return_type=return_type)
# get_next_by, get_previous_by for Date, DateTime
for field in self.django_context.get_model_fields(runtime_model_cls):
if isinstance(field, (DateField, DateTimeField)) and not field.null:
return_type = Instance(self.class_defn.info, [])
common.add_method(self.ctx,
name='get_next_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
common.add_method(self.ctx,
name='get_previous_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
class ModelCallback(helpers.ClassDefPluginCallback):
def __call__(self, ctx: ClassDefContext) -> None:
callback_classes = [
@@ -230,6 +294,9 @@ class ModelCallback(helpers.ClassDefPluginCallback):
AddForeignPrimaryKeys,
AddDefaultManagerCallback,
AddRelatedManagersCallback,
InjectAnyAsBaseForNestedMeta,
AddMetaOptionsAttribute,
AddExtraFieldMethods,
]
for callback_cls in callback_classes:
callback = callback_cls(self.plugin)

View File

@@ -22,5 +22,9 @@ class NameNotFound(IncompleteDefnError):
super().__init__(f'Could not find {name!r} in the current activated namespaces')
class SymbolAdditionNotPossible(Exception):
pass
def get_class_fullname(klass: type) -> str:
return klass.__module__ + '.' + klass.__qualname__

View File

@@ -0,0 +1,69 @@
from mypy.checker import gen_unique_name
from mypy.plugin import AttributeContext
from mypy.types import Instance
from mypy.types import Type as MypyType
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel, ManyToOneRel, ManyToManyRel
from mypy_django_plugin.lib import helpers, fullnames
from mypy_django_plugin.lib.helpers import GetAttributeCallback
class GetRelatedManagerCallback(GetAttributeCallback):
obj_type: Instance
def get_related_manager_type(self, relation: ForeignObjectRel) -> MypyType:
related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None:
# could not find a referenced model (maybe invalid to= value, or GenericForeignKey)
# TODO: show error
return self.default_attr_type
related_model_info = self.lookup_typeinfo(helpers.get_class_fullname(related_model_cls))
if related_model_info is None:
# TODO: show error
return self.default_attr_type
if isinstance(relation, OneToOneRel):
return Instance(related_model_info, [])
elif isinstance(relation, (ManyToOneRel, ManyToManyRel)):
related_manager_info = self.lookup_typeinfo(fullnames.RELATED_MANAGER_CLASS)
if related_manager_info is None:
return self.default_attr_type
# get type of default_manager for model
default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__)
default_manager_info = self.lookup_typeinfo(default_manager_fullname)
if default_manager_info is None:
return self.default_attr_type
default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])])
related_manager_type = Instance(related_manager_info,
[Instance(related_model_info, [])])
if (not isinstance(default_manager_type, Instance)
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
# if not defined or trivial -> just return RelatedManager[Model]
return related_manager_type
# make anonymous class
name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager',
self.obj_type.type.names)
bases = [related_manager_type, default_manager_type]
new_manager_info = self.new_typeinfo(name, bases)
return Instance(new_manager_info, [])
def __call__(self, ctx: AttributeContext):
super().__call__(ctx)
assert isinstance(self.obj_type, Instance)
model_fullname = self.obj_type.type.fullname
model_cls = self.django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return self.default_attr_type
for reverse_manager_name, relation in self.django_context.get_model_relations(model_cls):
if reverse_manager_name == self.name:
return self.get_related_manager_type(relation)
return self.default_attr_type

View File

@@ -653,7 +653,7 @@
- case: related_manager_is_a_subclass_of_default_manager
main: |
from myapp.models import User
reveal_type(User().orders) # N: Revealed type is 'myapp.models.User.Order_RelatedManager'
reveal_type(User().orders) # N: Revealed type is 'main.Order_RelatedManager'
reveal_type(User().orders.get()) # N: Revealed type is 'myapp.models.Order*'
reveal_type(User().orders.manager_method()) # N: Revealed type is 'builtins.int'
installed_apps:

View File

@@ -1,95 +0,0 @@
- case: anonymous_queryset_from_as_manager_inside_model
main: |
from myapp.models import MyModel
reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_MyModel'
reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
reveal_type(MyModel.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(MyModel.objects.queryset_method()) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
class MyModel(models.Model):
objects = MyQuerySet.as_manager()
- case: two_invocations_parametrized_with_different_models
main: |
from myapp.models import User, Blog
reveal_type(User.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_User'
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(User.objects.queryset_method()) # N: Revealed type is 'builtins.int'
reveal_type(Blog.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Blog'
reveal_type(Blog.objects.get()) # N: Revealed type is 'myapp.models.Blog*'
reveal_type(Blog.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int'
reveal_type(Blog.objects.queryset_method()) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
class User(models.Model):
objects = MyQuerySet.as_manager()
class Blog(models.Model):
objects = MyQuerySet.as_manager()
- case: as_manager_outside_model_parametrized_with_any
main: |
from myapp.models import NotModel, outside_objects
reveal_type(NotModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any'
reveal_type(NotModel.objects.get()) # N: Revealed type is 'Any'
reveal_type(outside_objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any'
reveal_type(outside_objects.get()) # N: Revealed type is 'Any'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def queryset_method(self) -> int:
pass
outside_objects = MyQuerySet.as_manager()
class NotModel:
objects = MyQuerySet.as_manager()
- case: test_as_manager_without_name_to_bind_in_different_files
main: |
from myapp.models import MyQuerySet
reveal_type(MyQuerySet.as_manager()) # N: Revealed type is 'Any'
reveal_type(MyQuerySet.as_manager().get()) # N: Revealed type is 'Any'
reveal_type(MyQuerySet.as_manager().mymethod()) # N: Revealed type is 'Any'
from myapp import helpers
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyQuerySet(models.QuerySet):
def mymethod(self) -> int:
pass
class MyModel(models.Model):
objects = MyQuerySet.as_manager()
- path: myapp/helpers.py
content: |
from myapp.models import MyQuerySet
MyQuerySet.as_manager()

View File

@@ -17,11 +17,15 @@
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import TypeVar
from django.db import models
from django.db.models.manager import BaseManager, Manager
from mypy_django_plugin.lib import generics
class ModelQuerySet(models.QuerySet):
generics.make_classes_generic(models.QuerySet)
_M = TypeVar('_M', bound=models.Model)
class ModelQuerySet(models.QuerySet[_M]):
def queryset_method(self) -> str:
return 'hello'

View File

@@ -0,0 +1,3 @@
digraph {
FuncDef [label="My FuncDef"]
}

Binary file not shown.