refactor, fix method copying

This commit is contained in:
Maxim Kurnikov
2020-01-04 13:36:11 +03:00
parent 7ba578f6b2
commit 13e40ab4a1
17 changed files with 706 additions and 427 deletions

View File

@@ -21,7 +21,7 @@ from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeOfAny, UnionType from mypy.types import TypeOfAny, UnionType
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
try: try:
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
@@ -356,11 +356,11 @@ class DjangoContext:
return AnyType(TypeOfAny.explicit) return AnyType(TypeOfAny.explicit)
if lookup_cls is None or isinstance(lookup_cls, Exact): if lookup_cls is None or isinstance(lookup_cls, Exact):
return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) return self.get_field_lookup_exact_type(chk_helpers.get_typechecker_api(ctx), field)
assert lookup_cls is not None assert lookup_cls is not None
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) lookup_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), lookup_cls)
if lookup_info is None: if lookup_info is None:
return AnyType(TypeOfAny.explicit) return AnyType(TypeOfAny.explicit)
@@ -370,7 +370,7 @@ class DjangoContext:
# if it's Field, consider lookup_type a __get__ of current field # if it's Field, consider lookup_type a __get__ of current field
if (isinstance(lookup_type, Instance) if (isinstance(lookup_type, Instance)
and lookup_type.type.fullname == fullnames.FIELD_FULLNAME): and lookup_type.type.fullname == fullnames.FIELD_FULLNAME):
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) field_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), field.__class__)
if field_info is None: if field_info is None:
return AnyType(TypeOfAny.explicit) return AnyType(TypeOfAny.explicit)
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',

View File

@@ -0,0 +1,112 @@
from typing import OrderedDict, List, Optional, Dict, Set, Union
from mypy import checker
from mypy.checker import TypeChecker
from mypy.nodes import MypyFile, TypeInfo, Var, MDEF, SymbolTableNode, GDEF, Expression
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, AttributeContext
from mypy.types import Type as MypyType, Instance, TupleType, TypeOfAny, AnyType, TypedDictType
from mypy_django_plugin.lib import helpers
def add_new_class_for_current_module(current_module: MypyFile,
name: str,
bases: List[Instance],
fields: Optional[Dict[str, MypyType]] = None
) -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, current_module.names)
new_typeinfo = helpers.new_typeinfo(new_class_unique_name,
bases=bases,
module_name=current_module.fullname)
# new_typeinfo = helpers.make_new_typeinfo_in_current_module(new_class_unique_name,
# bases=bases,
# current_module_fullname=current_module.fullname)
# add fields
if fields:
for field_name, field_type in fields.items():
var = Var(field_name, type=field_type)
var.info = new_typeinfo
var._fullname = new_typeinfo.fullname + '.' + field_name
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
current_module.defs.append(new_typeinfo.defn)
return new_typeinfo
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'Dict[str, MypyType]') -> TupleType:
current_module = helpers.get_current_module(api)
namedtuple_info = add_new_class_for_current_module(current_module, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
# fallback for tuples is any builtins.tuple instance
fallback = api.named_generic_type('builtins.tuple',
[AnyType(TypeOfAny.special_form)])
return TupleType(fields, fallback=fallback)
def make_oneoff_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
required_keys: Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker):
raise ValueError('Not a TypeChecker')
return ctx.api
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
api = get_typechecker_api(ctx)
api.check_subtype(actual_type, expected_type,
ctx.context, error_message,
'got', 'expected')
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""
Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
# type=: type of the variable itself
var = Var(name=name, type=sym_type)
# var.info: type of the object variable is bound to
var.info = info
var._fullname = info.fullname + '.' + name
var.is_initialized_in_class = True
var.is_inferred = True
info.names[name] = SymbolTableNode(MDEF, var,
plugin_generated=True)

View File

@@ -1,41 +1,33 @@
from collections import OrderedDict
from typing import ( from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union,
) )
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (
GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, PlaceholderNode, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
StrExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var, SymbolTable, SymbolTableNode, TypeInfo, Var,
) )
from mypy.plugin import (
AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext,
)
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType from mypy.types import AnyType, Instance, NoneTyp
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType from mypy.types import TypeOfAny, UnionType
from django.db.models.fields import Field
from mypy_django_plugin.lib import fullnames from mypy_django_plugin.lib import fullnames
if TYPE_CHECKING: if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer]
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {}) return model_info.metadata.setdefault('django', {})
class IncompleteDefnException(Exception):
pass
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]: def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
if '.' not in fullname: if '.' not in fullname:
return None return None
@@ -57,14 +49,14 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
return sym.node return sym.node
def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]: def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules) node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo): if not isinstance(node, TypeInfo):
return None return None
return node return node
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]: def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
fullname = get_class_fullname(klass) fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname) field_info = lookup_fully_qualified_typeinfo(api, fullname)
return field_info return field_info
@@ -79,36 +71,6 @@ def get_class_fullname(klass: type) -> str:
return klass.__module__ + '.' + klass.__qualname__ return klass.__module__ + '.' + klass.__qualname__
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""
Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]
def make_optional(typ: MypyType) -> MypyType: def make_optional(typ: MypyType) -> MypyType:
return UnionType.make_union([typ, NoneTyp()]) return UnionType.make_union([typ, NoneTyp()])
@@ -153,7 +115,7 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
return AnyType(TypeOfAny.explicit) return AnyType(TypeOfAny.explicit)
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType: def get_field_lookup_exact_type(api: AnyPluginAPI, field: Field) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)): if isinstance(field, (RelatedField, ForeignObjectRel)):
lookup_type_class = field.related_model lookup_type_class = field.related_model
rel_model_info = lookup_class_typeinfo(api, lookup_type_class) rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
@@ -168,44 +130,10 @@ def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
is_nullable=field.null) is_nullable=field.null)
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: def get_current_module(api: AnyPluginAPI) -> MypyFile:
metaclass_sym = info.names.get('Meta') if isinstance(api, SemanticAnalyzer):
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): return api.cur_mod_node
return metaclass_sym.node
return None
def add_new_class_for_module(module: MypyFile,
name: str,
bases: List[Instance],
fields: Optional[Dict[str, MypyType]] = None
) -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, module.names)
# make new class expression
classdef = ClassDef(new_class_unique_name, Block([]))
classdef.fullname = module.fullname + '.' + new_class_unique_name
# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname)
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
# add fields
if fields:
for field_name, field_type in fields.items():
var = Var(field_name, type=field_type)
var.info = new_typeinfo
var._fullname = new_typeinfo.fullname + '.' + field_name
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
classdef.info = new_typeinfo
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
return new_typeinfo
def get_current_module(api: TypeChecker) -> MypyFile:
current_module = None current_module = None
for item in reversed(api.scope.stack): for item in reversed(api.scope.stack):
if isinstance(item, MypyFile): if isinstance(item, MypyFile):
@@ -215,21 +143,6 @@ def get_current_module(api: TypeChecker) -> MypyFile:
return current_module return current_module
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
current_module = get_current_module(api)
namedtuple_info = add_new_class_for_module(current_module, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
# fallback for tuples is any builtins.tuple instance
fallback = api.named_generic_type('builtins.tuple',
[AnyType(TypeOfAny.special_form)])
return TupleType(fields, fallback=fallback)
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType): if isinstance(typ, UnionType):
converted_items = [] converted_items = []
@@ -252,13 +165,6 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
return typ return typ
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
required_keys: Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]: def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
if isinstance(attr_expr, StrExpr): if isinstance(attr_expr, StrExpr):
return attr_expr.value return attr_expr.value
@@ -272,104 +178,25 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: 'Djang
return None return None
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer: def is_subclass_of_model(info: TypeInfo, django_context: 'DjangoContext') -> bool:
if not isinstance(ctx.api, SemanticAnalyzer):
raise ValueError('Not a SemanticAnalyzer')
return ctx.api
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker):
raise ValueError('Not a TypeChecker')
return ctx.api
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
return (info.fullname in django_context.all_registered_model_class_fullnames return (info.fullname in django_context.all_registered_model_class_fullnames
or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
def check_types_compatible(ctx: Union[FunctionContext, MethodContext], def new_typeinfo(name: str,
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: *,
api = get_typechecker_api(ctx) bases: List[Instance],
api.check_subtype(actual_type, expected_type, module_name: str) -> TypeInfo:
ctx.context, error_message, """
'got', 'expected') Construct new TypeInfo instance. Cannot be used for nested classes.
"""
class_def = ClassDef(name, Block([]))
class_def.fullname = module_name + '.' + name
info = TypeInfo(SymbolTable(), class_def, module_name)
info.bases = bases
calculate_mro(info)
info.metaclass_type = info.calculate_metaclass_type()
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: class_def.info = info
# type=: type of the variable itself return info
var = Var(name=name, type=sym_type)
# var.info: type of the object variable is bound to
var.info = info
var._fullname = info.fullname + '.' + name
var.is_initialized_in_class = True
var.is_inferred = True
info.names[name] = SymbolTableNode(MDEF, var,
plugin_generated=True)
def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
prepared_arguments = []
for argument in method_node.arguments[1:]:
argument.type_annotation = AnyType(TypeOfAny.unannotated)
prepared_arguments.append(argument)
return_type = AnyType(TypeOfAny.unannotated)
return prepared_arguments, return_type
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
new_method_name: str, method_node: FuncDef) -> None:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
if not semanal_api.final_iteration:
semanal_api.defer()
return
arguments, return_type = build_unannotated_method_args(method_node)
add_method(ctx,
new_method_name,
args=arguments,
return_type=return_type,
self_type=self_type)
return
method_type = method_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return
arguments = []
bound_return_type = semanal_api.anal_type(method_type.ret_type,
allow_placeholder=True)
assert bound_return_type is not None
if isinstance(bound_return_type, PlaceholderNode):
return
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
method_type.arg_types[1:],
method_node.arguments[1:]):
bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True)
assert bound_arg_type is not None
if isinstance(bound_arg_type, PlaceholderNode):
return
var = Var(name=original_argument.variable.name,
type=arg_type)
var.line = original_argument.variable.line
var.column = original_argument.variable.column
argument = Argument(variable=var,
type_annotation=bound_arg_type,
initializer=original_argument.initializer,
kind=original_argument.kind)
argument.set_line(original_argument)
arguments.append(argument)
add_method(ctx,
new_method_name,
args=arguments,
return_type=bound_return_type,
self_type=self_type)

View File

@@ -0,0 +1,117 @@
from typing import Union, Tuple, List, Optional, NamedTuple, cast
from mypy.nodes import Argument, FuncDef, Var, TypeInfo
from mypy.plugin import DynamicClassDefContext, ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer
from mypy.types import Instance, CallableType, AnyType, TypeOfAny, PlaceholderType
from mypy.types import Type as MypyType
class IncompleteDefnException(Exception):
def __init__(self, error_message: str = '') -> None:
super().__init__(error_message)
class BoundNameNotFound(IncompleteDefnException):
def __init__(self, fullname: str) -> None:
super().__init__(f'No {fullname!r} found')
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
return cast(SemanticAnalyzer, ctx.api)
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
metaclass_sym = info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def prepare_unannotated_method_signature(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
prepared_arguments = []
for argument in method_node.arguments[1:]:
argument.type_annotation = AnyType(TypeOfAny.unannotated)
prepared_arguments.append(argument)
return_type = AnyType(TypeOfAny.unannotated)
return prepared_arguments, return_type
class SignatureTuple(NamedTuple):
arguments: Optional[List[Argument]]
return_type: Optional[MypyType]
cannot_be_bound: bool
def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> SignatureTuple:
method_type = method_node.type
assert isinstance(method_type, CallableType)
arguments = []
unbound = False
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
method_type.arg_types[1:],
method_node.arguments[1:]):
arg_type = api.anal_type(arg_type, allow_placeholder=True)
if isinstance(arg_type, PlaceholderType):
unbound = True
var = Var(name=original_argument.variable.name,
type=arg_type)
var.set_line(original_argument.variable)
if isinstance(arg_type, PlaceholderType):
unbound = True
argument = Argument(variable=var,
type_annotation=arg_type,
initializer=original_argument.initializer,
kind=original_argument.kind)
argument.set_line(original_argument)
arguments.append(argument)
ret_type = api.anal_type(method_type.ret_type, allow_placeholder=True)
if isinstance(ret_type, PlaceholderType):
unbound = True
return SignatureTuple(arguments, ret_type, unbound)
def copy_method_or_incomplete_defn_exception(ctx: ClassDefContext,
self_type: Instance,
new_method_name: str,
method_node: FuncDef) -> None:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
if not semanal_api.final_iteration:
raise IncompleteDefnException(f'Unannotated method {method_node.fullname!r}')
arguments, return_type = prepare_unannotated_method_signature(method_node)
add_method(ctx,
new_method_name,
args=arguments,
return_type=return_type,
self_type=self_type)
return
assert isinstance(method_node.type, CallableType)
# copy global SymbolTableNode objects from original class to the current node, if not present
original_module = semanal_api.modules[method_node.info.module_name]
for name, sym in original_module.names.items():
if (not sym.plugin_generated
and name not in semanal_api.cur_mod_node.names):
semanal_api.add_imported_symbol(name, sym, context=semanal_api.cur_mod_node)
arguments, return_type, unbound = analyze_callable_signature(semanal_api, method_node)
assert len(arguments) + 1 == len(method_node.arguments)
if unbound:
raise IncompleteDefnException(f'Signature of method {method_node.fullname!r} is not ready')
if new_method_name in ctx.cls.info.names:
del ctx.cls.info.names[new_method_name]
add_method(ctx,
new_method_name,
args=arguments,
return_type=return_type,
self_type=self_type)

View File

@@ -9,6 +9,7 @@ from mypy.options import Options
from mypy.plugin import ( from mypy.plugin import (
AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin,
) )
from mypy.semanal import dummy_context
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
import mypy_django_plugin.transformers.orm_lookups import mypy_django_plugin.transformers.orm_lookups
@@ -30,7 +31,7 @@ def transform_model_class(ctx: ClassDefContext,
if sym is not None and isinstance(sym.node, TypeInfo): if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
else: else:
if not ctx.api.final_iteration: if not ctx.api.final_iteration and not ctx.api.deferred:
ctx.api.defer() ctx.api.defer()
return return
@@ -180,7 +181,7 @@ class NewSemanalDjangoPlugin(Plugin):
if info.has_base(fullnames.FIELD_FULLNAME): if info.has_base(fullnames.FIELD_FULLNAME):
return partial(fields.transform_into_proper_return_type, django_context=self.django_context) return partial(fields.transform_into_proper_return_type, django_context=self.django_context)
if helpers.is_model_subclass_info(info, self.django_context): if helpers.is_subclass_of_model(info, self.django_context):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
return None return None

View File

@@ -9,13 +9,13 @@ from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None if (outer_model_info is None
or not helpers.is_model_subclass_info(outer_model_info, django_context)): or not helpers.is_subclass_of_model(outer_model_info, django_context)):
return None return None
field_name = None field_name = None
@@ -66,21 +66,21 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
# __get__/__set__ of ForeignKey of derived model # __get__/__set__ of ForeignKey of derived model
for model_cls in django_context.all_registered_model_classes: for model_cls in django_context.all_registered_model_classes:
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract: if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) derived_model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls)
if derived_model_info is not None: if derived_model_info is not None:
fk_ref_type = Instance(derived_model_info, []) fk_ref_type = Instance(derived_model_info, [])
derived_fk_type = reparametrize_related_field_type(default_related_field_type, derived_fk_type = reparametrize_related_field_type(default_related_field_type,
set_type=fk_ref_type, get_type=fk_ref_type) set_type=fk_ref_type, get_type=fk_ref_type)
helpers.add_new_sym_for_info(derived_model_info, chk_helpers.add_new_sym_for_info(derived_model_info,
name=current_field.name, name=current_field.name,
sym_type=derived_fk_type) sym_type=derived_fk_type)
related_model = related_model_cls related_model = related_model_cls
related_model_to_set = related_model_cls related_model_to_set = related_model_cls
if related_model_to_set._meta.proxy_for_model is not None: if related_model_to_set._meta.proxy_for_model is not None:
related_model_to_set = related_model_to_set._meta.proxy_for_model related_model_to_set = related_model_to_set._meta.proxy_for_model
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = chk_helpers.get_typechecker_api(ctx)
related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model) related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model)
if related_model_info is None: if related_model_info is None:
@@ -114,7 +114,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type) default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = False is_nullable = False
null_expr = helpers.get_call_argument_by_name(ctx, 'null') null_expr = chk_helpers.get_call_argument_by_name(ctx, 'null')
if null_expr is not None: if null_expr is not None:
is_nullable = helpers.parse_bool(null_expr) or False is_nullable = helpers.parse_bool(null_expr) or False
@@ -122,10 +122,10 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: def determine_type_of_array_field(ctx: FunctionContext) -> MypyType:
default_return_type = set_descriptor_types_for_field(ctx) default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field') base_field_arg_type = chk_helpers.get_call_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type return default_return_type
@@ -141,9 +141,9 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
default_return_type = ctx.default_return_type default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance) assert isinstance(default_return_type, Instance)
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None if (outer_model_info is None
or not helpers.is_model_subclass_info(outer_model_info, django_context)): or not helpers.is_subclass_of_model(outer_model_info, django_context)):
return ctx.default_return_type return ctx.default_return_type
assert isinstance(outer_model_info, TypeInfo) assert isinstance(outer_model_info, TypeInfo)
@@ -152,6 +152,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
return fill_descriptor_types_for_related_field(ctx, django_context) return fill_descriptor_types_for_related_field(ctx, django_context)
if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx, django_context) return determine_type_of_array_field(ctx)
return set_descriptor_types_for_field(ctx) return set_descriptor_types_for_field(ctx)

View File

@@ -5,11 +5,11 @@ from mypy.types import CallableType, Instance, NoneTyp
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeType from mypy.types import TypeType
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import sem_helpers, chk_helpers
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info) meta_node = sem_helpers.get_nested_meta_node_for_current_class(ctx.cls.info)
if meta_node is None: if meta_node is None:
if not ctx.api.final_iteration: if not ctx.api.final_iteration:
ctx.api.defer() ctx.api.defer()
@@ -28,7 +28,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType:
object_type = ctx.type object_type = ctx.type
assert isinstance(object_type, Instance) assert isinstance(object_type, Instance)
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class') form_class_type = chk_helpers.get_call_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp): if form_class_type is None or isinstance(form_class_type, NoneTyp):
form_class_type = get_specified_form_class(object_type) form_class_type = get_specified_form_class(object_type)

View File

@@ -6,7 +6,7 @@ from mypy.types import Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import chk_helpers
def get_actual_types(ctx: Union[MethodContext, FunctionContext], def get_actual_types(ctx: Union[MethodContext, FunctionContext],
@@ -32,7 +32,7 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType: model_cls: Type[Model], method: str) -> MypyType:
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = chk_helpers.get_typechecker_api(ctx)
expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method) expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method)
expected_keys = [key for key in expected_types.keys() if key != 'pk'] expected_keys = [key for key in expected_types.keys() if key != 'pk']
@@ -42,11 +42,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
model_cls.__name__), model_cls.__name__),
ctx.context) ctx.context)
continue continue
helpers.check_types_compatible(ctx, error_message = 'Incompatible type for "{}" of "{}"'.format(actual_name, model_cls.__name__)
expected_type=expected_types[actual_name], chk_helpers.check_types_compatible(ctx,
actual_type=actual_type, expected_type=expected_types[actual_name],
error_message='Incompatible type for "{}" of "{}"'.format(actual_name, actual_type=actual_type,
model_cls.__name__)) error_message=error_message)
return ctx.default_return_type return ctx.default_return_type

View File

@@ -1,77 +1,149 @@
from typing import Iterator, Tuple, Optional
from mypy.nodes import ( from mypy.nodes import (
GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo,
PlaceholderNode, SymbolTableNode, GDEF
) )
from mypy.plugin import ClassDefContext, DynamicClassDefContext from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.types import AnyType, Instance, TypeOfAny from mypy.types import AnyType, Instance, TypeOfAny
from mypy.typevars import fill_typevars
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, sem_helpers, helpers
def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None: def iter_all_custom_queryset_methods(derived_queryset_info: TypeInfo) -> Iterator[Tuple[str, FuncDef]]:
semanal_api = helpers.get_semanal_api(ctx) for base_queryset_info in derived_queryset_info.mro:
if base_queryset_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
break
for name, sym in base_queryset_info.names.items():
if isinstance(sym.node, FuncDef):
yield name, sym.node
def resolve_callee_manager_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
callee = ctx.call.callee callee = ctx.call.callee
assert isinstance(callee, MemberExpr) assert isinstance(callee, MemberExpr)
assert isinstance(callee.expr, RefExpr) assert isinstance(callee.expr, RefExpr)
base_manager_info = callee.expr.node callee_manager_info = callee.expr.node
if base_manager_info is None: if (callee_manager_info is None
if not semanal_api.final_iteration: or isinstance(callee_manager_info, PlaceholderNode)):
semanal_api.defer() raise sem_helpers.IncompleteDefnException(f'Definition of base manager {callee_manager_info.fullname} '
return f'is incomplete.')
assert isinstance(base_manager_info, TypeInfo) assert isinstance(callee_manager_info, TypeInfo)
new_manager_info = semanal_api.basic_new_typeinfo(ctx.name, return callee_manager_info
basetype_or_fallback=Instance(base_manager_info,
[AnyType(TypeOfAny.unannotated)]))
new_manager_info.line = ctx.call.line
new_manager_info.defn.line = ctx.call.line
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info,
plugin_generated=True)
passed_queryset = ctx.call.args[0]
assert isinstance(passed_queryset, NameExpr)
derived_queryset_fullname = passed_queryset.fullname def resolve_passed_queryset_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
assert derived_queryset_fullname is not None api = sem_helpers.get_semanal_api(ctx)
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) passed_queryset_name_expr = ctx.call.args[0]
assert sym is not None assert isinstance(passed_queryset_name_expr, NameExpr)
if sym.node is None:
if not semanal_api.final_iteration:
semanal_api.defer()
else:
# inherit from Any to prevent false-positives, if queryset class cannot be resolved
new_manager_info.fallback_to_any = True
return
derived_queryset_info = sym.node sym = api.lookup_qualified(passed_queryset_name_expr.name, ctx=ctx.call)
assert isinstance(derived_queryset_info, TypeInfo) if (sym is None
or sym.node is None
or isinstance(sym.node, PlaceholderNode)):
raise sem_helpers.BoundNameNotFound(passed_queryset_name_expr.fullname)
assert isinstance(sym.node, TypeInfo)
return sym.node
def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
api = sem_helpers.get_semanal_api(ctx)
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if (sym is None
or sym.node is None
or isinstance(sym.node, PlaceholderNode)):
raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME)
assert isinstance(sym.node, TypeInfo)
return sym.node
def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo:
callee_manager_type = Instance(callee_manager_info, [AnyType(TypeOfAny.unannotated)])
api = sem_helpers.get_semanal_api(ctx)
new_manager_class_name = ctx.name
new_manager_info = helpers.new_typeinfo(new_manager_class_name,
bases=[callee_manager_type], module_name=api.cur_mod_id)
new_manager_info.set_line(ctx.call)
return new_manager_info
def record_new_manager_info_fullname_into_metadata(ctx: DynamicClassDefContext,
new_manager_fullname: str,
callee_manager_info: TypeInfo,
queryset_info: TypeInfo,
django_manager_info: TypeInfo) -> None:
if len(ctx.call.args) > 1: if len(ctx.call.args) > 1:
expr = ctx.call.args[1] expr = ctx.call.args[1]
assert isinstance(expr, StrExpr) assert isinstance(expr, StrExpr)
custom_manager_generated_name = expr.value custom_manager_generated_name = expr.value
else: else:
custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name custom_manager_generated_name = callee_manager_info.name + 'From' + queryset_info.name
custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name]) custom_manager_generated_fullname = 'django.db.models.manager' + '.' + custom_manager_generated_name
if 'from_queryset_managers' not in base_manager_info.metadata:
base_manager_info.metadata['from_queryset_managers'] = {} metadata = django_manager_info.metadata.setdefault('from_queryset_managers', {})
base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname metadata[custom_manager_generated_fullname] = new_manager_fullname
def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None:
semanal_api = sem_helpers.get_semanal_api(ctx)
try:
callee_manager_info = resolve_callee_manager_info_or_exception(ctx)
queryset_info = resolve_passed_queryset_info_or_exception(ctx)
django_manager_info = resolve_django_manager_info_or_exception(ctx)
except sem_helpers.IncompleteDefnException:
if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
new_manager_info = new_manager_typeinfo(ctx, callee_manager_info)
record_new_manager_info_fullname_into_metadata(ctx,
new_manager_info.fullname,
callee_manager_info,
queryset_info,
django_manager_info)
class_def_context = ClassDefContext(cls=new_manager_info.defn, class_def_context = ClassDefContext(cls=new_manager_info.defn,
reason=ctx.call, api=semanal_api) reason=ctx.call, api=semanal_api)
self_type = Instance(new_manager_info, []) self_type = fill_typevars(new_manager_info)
# we need to copy all methods in MRO before django.db.models.query.QuerySet # self_type = Instance(new_manager_info, [])
for class_mro_info in derived_queryset_info.mro:
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: try:
break for name, method_node in iter_all_custom_queryset_methods(queryset_info):
for name, sym in class_mro_info.names.items(): sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context,
if isinstance(sym.node, FuncDef): self_type,
helpers.copy_method_to_another_class(class_def_context, new_method_name=name,
self_type, method_node=method_node)
new_method_name=name, except sem_helpers.IncompleteDefnException:
method_node=sym.node) if not semanal_api.final_iteration:
semanal_api.defer()
return
else:
raise
new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
# context=None - forcibly replace old node
added = semanal_api.add_symbol_table_node(ctx.name, new_manager_sym, context=None)
if added:
# replace all references to the old manager Var everywhere
for _, module in semanal_api.modules.items():
if module.fullname != semanal_api.cur_mod_id:
for sym_name, sym in module.names.items():
if sym.fullname == new_manager_info.fullname:
module.names[sym_name] = new_manager_sym.copy()
# we need another iteration to process methods
if (not added
and not semanal_api.final_iteration):
semanal_api.defer()

View File

@@ -5,12 +5,12 @@ from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers, chk_helpers
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), api = chk_helpers.get_typechecker_api(ctx)
field_fullname) field_info = helpers.lookup_fully_qualified_typeinfo(api, field_fullname)
if field_info is None: if field_info is None:
return AnyType(TypeOfAny.unannotated) return AnyType(TypeOfAny.unannotated)
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
@@ -32,7 +32,7 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
if model_cls is None: if model_cls is None:
return ctx.default_return_type return ctx.default_return_type
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name') field_name_expr = chk_helpers.get_call_argument_by_name(ctx, 'field_name')
if field_name_expr is None: if field_name_expr is None:
return ctx.default_return_type return ctx.default_return_type

View File

@@ -1,21 +1,21 @@
from typing import Dict, List, Optional, Type, cast from typing import List, Optional, Type, cast
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields import DateField, DateTimeField from django.db.models.fields.related import ForeignKey, OneToOneField
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ( from django.db.models.fields.reverse_related import (
ManyToManyRel, ManyToOneRel, OneToOneRel, ManyToManyRel, ManyToOneRel, OneToOneRel,
) )
from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var from mypy.nodes import ARG_STAR2, Argument, FuncDef, TypeInfo, Var, SymbolTableNode, MDEF, GDEF
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins import common from mypy.plugins import common
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer, dummy_context
from mypy.types import AnyType, Instance from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny
from django.db.models.fields import DateField, DateTimeField
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers, sem_helpers
from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types from mypy_django_plugin.transformers.fields import get_field_descriptor_types
@@ -35,7 +35,7 @@ class ModelClassInitializer:
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo: def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
info = self.lookup_typeinfo(fullname) info = self.lookup_typeinfo(fullname)
if info is None: if info is None:
raise helpers.IncompleteDefnException(f'No {fullname!r} found') raise sem_helpers.IncompleteDefnException(f'No {fullname!r} found')
return info return info
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo: def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
@@ -43,26 +43,74 @@ class ModelClassInitializer:
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname) field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
return field_info return field_info
def create_new_var(self, name: str, typ: MypyType) -> Var: def model_class_has_attribute_defined(self, name: str, traverse_mro: bool = True) -> bool:
# type=: type of the variable itself if not traverse_mro:
var = Var(name=name, type=typ) sym = self.model_classdef.info.names.get(name)
# var.info: type of the object variable is bound to else:
sym = self.model_classdef.info.get(name)
return sym is not None
def resolve_manager_fullname(self, manager_fullname: str) -> str:
base_manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
if (base_manager_info is None
or 'from_queryset_managers' not in base_manager_info.metadata):
return manager_fullname
metadata = base_manager_info.metadata['from_queryset_managers']
return metadata.get(manager_fullname, manager_fullname)
def add_new_node_to_model_class(self, name: str, typ: MypyType,
force_replace_existing: bool = False) -> None:
if not force_replace_existing and name in self.model_classdef.info.names:
raise ValueError(f'Member {name!r} already defined at model {self.model_classdef.info.fullname!r}.')
var = Var(name, type=typ)
# TypeInfo of the object variable is bound to
var.info = self.model_classdef.info var.info = self.model_classdef.info
var._fullname = self.model_classdef.info.fullname + '.' + name var._fullname = self.api.qualified_name(name)
var.is_initialized_in_class = True var.is_initialized_in_class = True
var.is_inferred = True
return var
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: sym = SymbolTableNode(MDEF, var, plugin_generated=True)
helpers.add_new_sym_for_info(self.model_classdef.info, context = dummy_context()
name=name, if force_replace_existing:
sym_type=typ) context = None
self.api.add_symbol_table_node(name, sym, context=context)
def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo: def add_new_class_for_current_module(self, name: str, bases: List[Instance],
current_module = self.api.modules[self.model_classdef.info.module_name] force_replace_existing: bool = False) -> Optional[TypeInfo]:
new_class_info = helpers.add_new_class_for_module(current_module, current_module = self.api.cur_mod_node
name=name, bases=bases) if not force_replace_existing and name in current_module:
return new_class_info raise ValueError(f'Class {name!r} already defined for module {current_module.fullname!r}')
new_typeinfo = helpers.new_typeinfo(name,
bases=bases,
module_name=current_module.fullname)
# sym = SymbolTableNode(GDEF, new_typeinfo,
# plugin_generated=True)
# context = dummy_context()
# if force_replace_existing:
# context = None
if name in current_module.names:
del current_module.names[name]
current_module.names[name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
# current_module.defs.append(new_typeinfo.defn)
# self.api.cur_mod_node.
# self.api.leave_class()
# added = self.api.add_symbol_table_node(name, sym, context=context)
# self.api.enter_class(self.model_classdef.info)
#
# self.api.cur_mod_node.defs.append(new_typeinfo.defn)
# if not added and force_replace_existing:
# return None
return new_typeinfo
# current_module = self.api.modules[self.model_classdef.info.module_name]
# context =
# new_class_info = helpers.add_new_class_for_module(current_module,
# name=name, bases=bases)
# return new_class_info
def run(self) -> None: def run(self) -> None:
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
@@ -88,58 +136,90 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
""" """
def run(self) -> None: def run(self) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
if meta_node is None: if meta_node is None:
return None return None
meta_node.fallback_to_any = True meta_node.fallback_to_any = True
class AddDefaultPrimaryKey(ModelClassInitializer): class AddDefaultPrimaryKey(ModelClassInitializer):
"""
Adds default primary key to models which does not define their own.
```
class User(models.Model):
name = models.TextField()
```
"""
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
auto_field = model_cls._meta.auto_field auto_field = model_cls._meta.auto_field
if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname): if auto_field is None:
# autogenerated field return
auto_field_fullname = helpers.get_class_fullname(auto_field.__class__)
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname)
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False) primary_key_attrname = auto_field.attname
self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, if self.model_class_has_attribute_defined(primary_key_attrname):
[set_type, get_type])) return
auto_field_class_fullname = helpers.get_class_fullname(auto_field.__class__)
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_class_fullname)
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False)
self.add_new_node_to_model_class(primary_key_attrname, Instance(auto_field_info,
[set_type, get_type]))
class AddRelatedModelsId(ModelClassInitializer): class AddRelatedModelsId(ModelClassInitializer):
"""
Adds `FIELDNAME_id` attributes to models.
```
class User(models.Model):
pass
class Blog(models.Model):
user = models.ForeignKey(User)
```
`user_id` will be added to `Blog`.
"""
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey): if not isinstance(field, (OneToOneField, ForeignKey)):
related_model_cls = self.django_context.get_field_related_model_cls(field) continue
if related_model_cls is None: related_id_attr_name = field.attname
error_context: Context = self.ctx.cls if self.model_class_has_attribute_defined(related_id_attr_name):
field_sym = self.ctx.cls.info.get(field.name) continue
if field_sym is not None and field_sym.node is not None: # if self.get_model_class_attr(related_id_attr_name) is not None:
error_context = field_sym.node # continue
self.api.fail(f'Cannot find model {field.related_model!r} '
f'referenced in field {field.name!r} ', related_model_cls = self.django_context.get_field_related_model_cls(field)
ctx=error_context) if related_model_cls is None:
self.add_new_node_to_model_class(field.attname, error_context = self.ctx.cls
AnyType(TypeOfAny.explicit)) field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
error_context = field_sym.node
self.api.fail(f'Cannot find model {field.related_model!r} '
f'referenced in field {field.name!r} ',
ctx=error_context)
self.add_new_node_to_model_class(related_id_attr_name,
AnyType(TypeOfAny.explicit))
continue
if related_model_cls._meta.abstract:
continue
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
except sem_helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue continue
if related_model_cls._meta.abstract: is_nullable = self.django_context.get_field_nullability(field, None)
continue set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
self.add_new_node_to_model_class(related_id_attr_name,
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) Instance(field_info, [set_type, get_type]))
try:
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
except helpers.IncompleteDefnException as exc:
if not self.api.final_iteration:
raise exc
else:
continue
is_nullable = self.django_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
self.add_new_node_to_model_class(field.attname,
Instance(field_info, [set_type, get_type]))
class AddManagers(ModelClassInitializer): class AddManagers(ModelClassInitializer):
@@ -152,25 +232,15 @@ class AddManagers(ModelClassInitializer):
def is_any_parametrized_manager(self, typ: Instance) -> bool: def is_any_parametrized_manager(self, typ: Instance) -> bool:
return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType) return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType)
def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]:
base_manager_info = self.lookup_typeinfo(base_manager_fullname)
if (base_manager_info is None
or 'from_queryset_managers' not in base_manager_info.metadata):
return {}
return base_manager_info.metadata['from_queryset_managers']
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance: def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
bases = [] bases = []
for original_base in base_manager_info.bases: for original_base in base_manager_info.bases:
if self.is_any_parametrized_manager(original_base): if self.is_any_parametrized_manager(original_base):
if original_base.type is None:
raise helpers.IncompleteDefnException()
original_base = helpers.reparametrize_instance(original_base, original_base = helpers.reparametrize_instance(original_base,
[Instance(self.model_classdef.info, [])]) [Instance(self.model_classdef.info, [])])
bases.append(original_base) bases.append(original_base)
new_manager_info = self.add_new_class_for_current_module(name, bases) new_manager_info = self.add_new_class_for_current_module(name, bases, force_replace_existing=True)
# copy fields to a new manager # copy fields to a new manager
new_cls_def_context = ClassDefContext(cls=new_manager_info.defn, new_cls_def_context = ClassDefContext(cls=new_manager_info.defn,
reason=self.ctx.reason, reason=self.ctx.reason,
@@ -178,12 +248,15 @@ class AddManagers(ModelClassInitializer):
custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])]) custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])])
for name, sym in base_manager_info.names.items(): for name, sym in base_manager_info.names.items():
if name in new_manager_info.names:
raise ValueError(f'Name {name!r} already exists on newly-created {new_manager_info.fullname!r} class.')
# replace self type with new class, if copying method # replace self type with new class, if copying method
if isinstance(sym.node, FuncDef): if isinstance(sym.node, FuncDef):
helpers.copy_method_to_another_class(new_cls_def_context, sem_helpers.copy_method_or_incomplete_defn_exception(new_cls_def_context,
self_type=custom_manager_type, self_type=custom_manager_type,
new_method_name=name, new_method_name=name,
method_node=sym.node) method_node=sym.node)
continue continue
new_sym = sym.copy() new_sym = sym.copy()
@@ -192,32 +265,36 @@ class AddManagers(ModelClassInitializer):
new_var.info = new_manager_info new_var.info = new_manager_info
new_var._fullname = new_manager_info.fullname + '.' + name new_var._fullname = new_manager_info.fullname + '.' + name
new_sym.node = new_var new_sym.node = new_var
new_manager_info.names[name] = new_sym new_manager_info.names[name] = new_sym
return custom_manager_type return custom_manager_type
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for manager_name, manager in model_cls._meta.managers_map.items(): for manager_name, manager in model_cls._meta.managers_map.items():
manager_class_name = manager.__class__.__name__ if self.model_class_has_attribute_defined(manager_name, traverse_mro=False):
manager_fullname = helpers.get_class_fullname(manager.__class__) sym = self.model_classdef.info.names.get(manager_name)
try: assert sym is not None
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
except helpers.IncompleteDefnException as exc: if (sym.type is not None
if not self.api.final_iteration: and isinstance(sym.type, Instance)
raise exc and sym.type.type.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)
else: and not self.has_any_parametrized_manager_as_base(sym.type.type)):
base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) # already defined and parametrized properly
generated_managers = self.get_generated_manager_mappings(base_manager_fullname) continue
if manager_fullname not in generated_managers:
# not a generated manager, continue with the loop if getattr(manager, '_built_with_as_manager', False):
continue # as_manager is not supported yet
real_manager_fullname = generated_managers[manager_fullname] if not self.model_class_has_attribute_defined(manager_name, traverse_mro=True):
manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore self.add_new_node_to_model_class(manager_name, AnyType(TypeOfAny.explicit))
if manager_info is None: continue
continue
manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1] manager_fullname = self.resolve_manager_fullname(helpers.get_class_fullname(manager.__class__))
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
manager_class_name = manager_fullname.rsplit('.', maxsplit=1)[1]
if manager_name not in self.model_classdef.info.names: if manager_name not in self.model_classdef.info.names:
# manager not yet defined, just add models.Manager[ModelName]
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, manager_type) self.add_new_node_to_model_class(manager_name, manager_type)
else: else:
@@ -226,56 +303,67 @@ class AddManagers(ModelClassInitializer):
continue continue
custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name
try: custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name,
custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name, base_manager_info=manager_info)
base_manager_info=manager_info)
except helpers.IncompleteDefnException:
continue
self.add_new_node_to_model_class(manager_name, custom_manager_type) self.add_new_node_to_model_class(manager_name, custom_manager_type,
force_replace_existing=True)
class AddDefaultManagerAttribute(ModelClassInitializer): class AddDefaultManagerAttribute(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
# add _default_manager if self.model_class_has_attribute_defined('_default_manager', traverse_mro=False):
if '_default_manager' not in self.model_classdef.info.names: return
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__) if model_cls._meta.default_manager is None:
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname) return
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])]) if getattr(model_cls._meta.default_manager, '_built_with_as_manager', False):
self.add_new_node_to_model_class('_default_manager', default_manager) self.add_new_node_to_model_class('_default_manager',
AnyType(TypeOfAny.explicit))
return
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
resolved_default_manager_fullname = self.resolve_manager_fullname(default_manager_fullname)
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(resolved_default_manager_fullname)
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class('_default_manager', default_manager)
class AddRelatedManagers(ModelClassInitializer): class AddRelatedManagers(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
# add related managers # add related managers
for relation in self.django_context.get_model_relations(model_cls): for relation in self.django_context.get_model_relations(model_cls):
attname = relation.get_accessor_name() related_manager_attr_name = relation.get_accessor_name()
if attname is None: if related_manager_attr_name is None:
# no reverse accessor # no reverse accessor
continue continue
if self.model_class_has_attribute_defined(related_manager_attr_name, traverse_mro=False):
continue
related_model_cls = self.django_context.get_field_related_model_cls(relation) related_model_cls = self.django_context.get_field_related_model_cls(relation)
if related_model_cls is None: if related_model_cls is None:
continue continue
try: try:
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
except helpers.IncompleteDefnException as exc: except sem_helpers.IncompleteDefnException as exc:
if not self.api.final_iteration: if not self.api.final_iteration:
raise exc raise exc
else: else:
continue continue
if isinstance(relation, OneToOneRel): if isinstance(relation, OneToOneRel):
self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) self.add_new_node_to_model_class(related_manager_attr_name, Instance(related_model_info, []))
continue continue
if isinstance(relation, (ManyToOneRel, ManyToManyRel)): if isinstance(relation, (ManyToOneRel, ManyToManyRel)):
try: try:
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS) # noqa: E501 related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(
fullnames.RELATED_MANAGER_CLASS) # noqa: E501
if 'objects' not in related_model_info.names: if 'objects' not in related_model_info.names:
raise helpers.IncompleteDefnException() raise sem_helpers.IncompleteDefnException()
except helpers.IncompleteDefnException as exc: except sem_helpers.IncompleteDefnException as exc:
if not self.api.final_iteration: if not self.api.final_iteration:
raise exc raise exc
else: else:
@@ -288,14 +376,20 @@ class AddRelatedManagers(ModelClassInitializer):
if (default_manager_type is None if (default_manager_type is None
or not isinstance(default_manager_type, Instance) or not isinstance(default_manager_type, Instance)
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
self.add_new_node_to_model_class(attname, parametrized_related_manager_type) self.add_new_node_to_model_class(related_manager_attr_name, parametrized_related_manager_type)
continue continue
name = related_model_cls.__name__ + '_' + 'RelatedManager' name = related_model_cls.__name__ + '_' + 'RelatedManager'
bases = [parametrized_related_manager_type, default_manager_type] bases = [parametrized_related_manager_type, default_manager_type]
new_related_manager_info = self.add_new_class_for_current_module(name, bases) new_related_manager_info = self.add_new_class_for_current_module(name, bases,
force_replace_existing=True)
if new_related_manager_info is None:
# wasn't added for some reason, defer
if not self.api.final_iteration:
self.api.defer()
continue
self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, [])) self.add_new_node_to_model_class(related_manager_attr_name, Instance(new_related_manager_info, []))
class AddExtraFieldMethods(ModelClassInitializer): class AddExtraFieldMethods(ModelClassInitializer):
@@ -355,6 +449,8 @@ def process_model_class(ctx: ClassDefContext,
for initializer_cls in initializers: for initializer_cls in initializers:
try: try:
initializer_cls(ctx, django_context).run() initializer_cls(ctx, django_context).run()
except helpers.IncompleteDefnException: except sem_helpers.IncompleteDefnException as exc:
if not ctx.api.final_iteration: if not ctx.api.final_iteration:
ctx.api.defer() ctx.api.defer()
continue
raise exc

View File

@@ -4,7 +4,7 @@ from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
@@ -35,10 +35,10 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
fullnames.QUERYSET_CLASS_FULLNAME))): fullnames.QUERYSET_CLASS_FULLNAME))):
return ctx.default_return_type return ctx.default_return_type
helpers.check_types_compatible(ctx, chk_helpers.check_types_compatible(ctx,
expected_type=lookup_type, expected_type=lookup_type,
actual_type=provided_type, actual_type=provided_type,
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
return ctx.default_return_type return ctx.default_return_type

View File

@@ -14,7 +14,7 @@ from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import ( from mypy_django_plugin.django.context import (
DjangoContext, LookupsAreUnsupported, DjangoContext, LookupsAreUnsupported,
) )
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]: def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
@@ -30,7 +30,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
default_return_type = ctx.default_return_type default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance) assert isinstance(default_return_type, Instance)
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None if (outer_model_info is None
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
return default_return_type return default_return_type
@@ -55,7 +55,7 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
lookup_field = django_context.get_primary_key_field(related_model_cls) lookup_field = django_context.get_primary_key_field(related_model_cls)
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), field_get_type = django_context.get_field_get_type(chk_helpers.get_typechecker_api(ctx),
lookup_field, method=method) lookup_field, method=method)
return field_get_type return field_get_type
@@ -66,7 +66,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
if field_lookups is None: if field_lookups is None:
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = chk_helpers.get_typechecker_api(ctx)
if len(field_lookups) == 0: if len(field_lookups) == 0:
if flat: if flat:
primary_key_field = django_context.get_primary_key_field(model_cls) primary_key_field = django_context.get_primary_key_field(model_cls)
@@ -80,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
column_type = django_context.get_field_get_type(typechecker_api, field, column_type = django_context.get_field_get_type(typechecker_api, field,
method='values_list') method='values_list')
column_types[field.attname] = column_type column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) return chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
else: else:
# flat=False, named=False, all fields # flat=False, named=False, all fields
field_lookups = [] field_lookups = []
@@ -103,9 +103,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
assert len(column_types) == 1 assert len(column_types) == 1
row_type = next(iter(column_types.values())) row_type = next(iter(column_types.values()))
elif named: elif named:
row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) row_type = chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
else: else:
row_type = helpers.make_tuple(typechecker_api, list(column_types.values())) row_type = chk_helpers.make_tuple(typechecker_api, list(column_types.values()))
return row_type return row_type
@@ -123,13 +123,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
if model_cls is None: if model_cls is None:
return ctx.default_return_type return ctx.default_return_type
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat') flat_expr = chk_helpers.get_call_argument_by_name(ctx, 'flat')
if flat_expr is not None and isinstance(flat_expr, NameExpr): if flat_expr is not None and isinstance(flat_expr, NameExpr):
flat = helpers.parse_bool(flat_expr) flat = helpers.parse_bool(flat_expr)
else: else:
flat = False flat = False
named_expr = helpers.get_call_argument_by_name(ctx, 'named') named_expr = chk_helpers.get_call_argument_by_name(ctx, 'named')
if named_expr is not None and isinstance(named_expr, NameExpr): if named_expr is not None and isinstance(named_expr, NameExpr):
named = helpers.parse_bool(named_expr) named = helpers.parse_bool(named_expr)
else: else:
@@ -188,5 +188,5 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
column_types[field_lookup] = field_lookup_type column_types[field_lookup] = field_lookup_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) row_type = chk_helpers.make_oneoff_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])

View File

@@ -3,13 +3,13 @@ from mypy.types import Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers, chk_helpers
def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType:
auth_user_model = django_context.settings.AUTH_USER_MODEL auth_user_model = django_context.settings.AUTH_USER_MODEL
model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls = django_context.apps_registry.get_model(auth_user_model)
model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls)
if model_info is None: if model_info is None:
return ctx.default_attr_type return ctx.default_attr_type

View File

@@ -5,7 +5,7 @@ from mypy.types import Type as MypyType
from mypy.types import TypeOfAny, TypeType from mypy.types import TypeOfAny, TypeType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers, chk_helpers
def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
@@ -13,7 +13,7 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) ->
model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls = django_context.apps_registry.get_model(auth_user_model)
model_cls_fullname = helpers.get_class_fullname(model_cls) model_cls_fullname = helpers.get_class_fullname(model_cls)
model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), model_info = helpers.lookup_fully_qualified_typeinfo(chk_helpers.get_typechecker_api(ctx),
model_cls_fullname) model_cls_fullname)
if model_info is None: if model_info is None:
return AnyType(TypeOfAny.unannotated) return AnyType(TypeOfAny.unannotated)
@@ -28,7 +28,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django
ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context)
return ctx.default_attr_type return ctx.default_attr_type
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = chk_helpers.get_typechecker_api(ctx)
# first look for the setting in the project settings file, then global settings # first look for the setting in the project settings file, then global settings
settings_module = typechecker_api.modules.get(django_context.django_settings_module) settings_module = typechecker_api.modules.get(django_context.django_settings_module)

View File

@@ -3,6 +3,7 @@
from myapp.models import MyModel from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*' reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is 'def () -> builtins.str'
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str' reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str'
installed_apps: installed_apps:
- myapp - myapp
@@ -178,4 +179,57 @@
from django.db import models from django.db import models
class BaseQuerySet(models.QuerySet): class BaseQuerySet(models.QuerySet):
def base_queryset_method(self, param: Union[int, str]) -> NoReturn: def base_queryset_method(self, param: Union[int, str]) -> NoReturn:
raise ValueError raise ValueError
- case: from_queryset_with_inherited_manager_and_fk_to_auth_contrib
disable_cache: true
main: |
from myapp.base_queryset import BaseQuerySet
reveal_type(BaseQuerySet().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
from django.contrib.auth.models import Permission
reveal_type(Permission().another_models) # N: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.AnotherModelInProjectWithContribAuthM2M]'
from myapp.managers import NewManager
reveal_type(NewManager()) # N: Revealed type is 'myapp.managers.NewManager'
reveal_type(NewManager().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
from myapp.models import MyModel
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
installed_apps:
- myapp
- django.contrib.auth
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
from myapp.managers import NewManager
from django.contrib.auth.models import Permission
class MyModel(models.Model):
objects = NewManager()
class AnotherModelInProjectWithContribAuthM2M(models.Model):
permissions = models.ForeignKey(
Permission,
on_delete=models.PROTECT,
related_name='another_models'
)
- path: myapp/managers.py
content: |
from django.db import models
from myapp.base_queryset import BaseQuerySet
class ModelQuerySet(BaseQuerySet):
pass
NewManager = models.Manager.from_queryset(ModelQuerySet)
- path: myapp/base_queryset.py
content: |
from typing import Union, Dict
from django.db import models
class BaseQuerySet(models.QuerySet):
def base_queryset_method(self, param: Dict[str, Union[int, str]]) -> Union[int, str]:
return param["hello"]

View File

@@ -307,15 +307,15 @@
- case: custom_manager_returns_proper_model_types - case: custom_manager_returns_proper_model_types
main: | main: |
from myapp.models import User from myapp.models import User
reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int' reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int'
reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'
from myapp.models import ChildUser from myapp.models import ChildUser
reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]'
reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]'
reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*' reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*'
reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int' reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int'
reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'