mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
refactor, fix method copying
This commit is contained in:
@@ -21,7 +21,7 @@ from mypy.types import AnyType, Instance
|
|||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
from mypy.types import TypeOfAny, UnionType
|
from mypy.types import TypeOfAny, UnionType
|
||||||
|
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
@@ -356,11 +356,11 @@ class DjangoContext:
|
|||||||
return AnyType(TypeOfAny.explicit)
|
return AnyType(TypeOfAny.explicit)
|
||||||
|
|
||||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||||
return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
return self.get_field_lookup_exact_type(chk_helpers.get_typechecker_api(ctx), field)
|
||||||
|
|
||||||
assert lookup_cls is not None
|
assert lookup_cls is not None
|
||||||
|
|
||||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
lookup_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), lookup_cls)
|
||||||
if lookup_info is None:
|
if lookup_info is None:
|
||||||
return AnyType(TypeOfAny.explicit)
|
return AnyType(TypeOfAny.explicit)
|
||||||
|
|
||||||
@@ -370,7 +370,7 @@ class DjangoContext:
|
|||||||
# if it's Field, consider lookup_type a __get__ of current field
|
# if it's Field, consider lookup_type a __get__ of current field
|
||||||
if (isinstance(lookup_type, Instance)
|
if (isinstance(lookup_type, Instance)
|
||||||
and lookup_type.type.fullname == fullnames.FIELD_FULLNAME):
|
and lookup_type.type.fullname == fullnames.FIELD_FULLNAME):
|
||||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
field_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), field.__class__)
|
||||||
if field_info is None:
|
if field_info is None:
|
||||||
return AnyType(TypeOfAny.explicit)
|
return AnyType(TypeOfAny.explicit)
|
||||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||||
|
|||||||
112
mypy_django_plugin/lib/chk_helpers.py
Normal file
112
mypy_django_plugin/lib/chk_helpers.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from typing import OrderedDict, List, Optional, Dict, Set, Union
|
||||||
|
|
||||||
|
from mypy import checker
|
||||||
|
from mypy.checker import TypeChecker
|
||||||
|
from mypy.nodes import MypyFile, TypeInfo, Var, MDEF, SymbolTableNode, GDEF, Expression
|
||||||
|
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, AttributeContext
|
||||||
|
from mypy.types import Type as MypyType, Instance, TupleType, TypeOfAny, AnyType, TypedDictType
|
||||||
|
|
||||||
|
from mypy_django_plugin.lib import helpers
|
||||||
|
|
||||||
|
|
||||||
|
def add_new_class_for_current_module(current_module: MypyFile,
|
||||||
|
name: str,
|
||||||
|
bases: List[Instance],
|
||||||
|
fields: Optional[Dict[str, MypyType]] = None
|
||||||
|
) -> TypeInfo:
|
||||||
|
new_class_unique_name = checker.gen_unique_name(name, current_module.names)
|
||||||
|
new_typeinfo = helpers.new_typeinfo(new_class_unique_name,
|
||||||
|
bases=bases,
|
||||||
|
module_name=current_module.fullname)
|
||||||
|
# new_typeinfo = helpers.make_new_typeinfo_in_current_module(new_class_unique_name,
|
||||||
|
# bases=bases,
|
||||||
|
# current_module_fullname=current_module.fullname)
|
||||||
|
# add fields
|
||||||
|
if fields:
|
||||||
|
for field_name, field_type in fields.items():
|
||||||
|
var = Var(field_name, type=field_type)
|
||||||
|
var.info = new_typeinfo
|
||||||
|
var._fullname = new_typeinfo.fullname + '.' + field_name
|
||||||
|
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||||
|
|
||||||
|
current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
|
||||||
|
current_module.defs.append(new_typeinfo.defn)
|
||||||
|
return new_typeinfo
|
||||||
|
|
||||||
|
|
||||||
|
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'Dict[str, MypyType]') -> TupleType:
|
||||||
|
current_module = helpers.get_current_module(api)
|
||||||
|
namedtuple_info = add_new_class_for_current_module(current_module, name,
|
||||||
|
bases=[api.named_generic_type('typing.NamedTuple', [])],
|
||||||
|
fields=fields)
|
||||||
|
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
|
||||||
|
|
||||||
|
|
||||||
|
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
|
||||||
|
# fallback for tuples is any builtins.tuple instance
|
||||||
|
fallback = api.named_generic_type('builtins.tuple',
|
||||||
|
[AnyType(TypeOfAny.special_form)])
|
||||||
|
return TupleType(fields, fallback=fallback)
|
||||||
|
|
||||||
|
|
||||||
|
def make_oneoff_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
|
||||||
|
required_keys: Set[str]) -> TypedDictType:
|
||||||
|
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
|
||||||
|
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
|
||||||
|
return typed_dict_type
|
||||||
|
|
||||||
|
|
||||||
|
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
|
||||||
|
if not isinstance(ctx.api, TypeChecker):
|
||||||
|
raise ValueError('Not a TypeChecker')
|
||||||
|
return ctx.api
|
||||||
|
|
||||||
|
|
||||||
|
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
|
||||||
|
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
|
||||||
|
api = get_typechecker_api(ctx)
|
||||||
|
api.check_subtype(actual_type, expected_type,
|
||||||
|
ctx.context, error_message,
|
||||||
|
'got', 'expected')
|
||||||
|
|
||||||
|
|
||||||
|
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
|
||||||
|
"""
|
||||||
|
Return the expression for the specific argument.
|
||||||
|
This helper should only be used with non-star arguments.
|
||||||
|
"""
|
||||||
|
if name not in ctx.callee_arg_names:
|
||||||
|
return None
|
||||||
|
idx = ctx.callee_arg_names.index(name)
|
||||||
|
args = ctx.args[idx]
|
||||||
|
if len(args) != 1:
|
||||||
|
# Either an error or no value passed.
|
||||||
|
return None
|
||||||
|
return args[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
|
||||||
|
"""Return the type for the specific argument.
|
||||||
|
|
||||||
|
This helper should only be used with non-star arguments.
|
||||||
|
"""
|
||||||
|
if name not in ctx.callee_arg_names:
|
||||||
|
return None
|
||||||
|
idx = ctx.callee_arg_names.index(name)
|
||||||
|
arg_types = ctx.arg_types[idx]
|
||||||
|
if len(arg_types) != 1:
|
||||||
|
# Either an error or no value passed.
|
||||||
|
return None
|
||||||
|
return arg_types[0]
|
||||||
|
|
||||||
|
|
||||||
|
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
|
||||||
|
# type=: type of the variable itself
|
||||||
|
var = Var(name=name, type=sym_type)
|
||||||
|
# var.info: type of the object variable is bound to
|
||||||
|
var.info = info
|
||||||
|
var._fullname = info.fullname + '.' + name
|
||||||
|
var.is_initialized_in_class = True
|
||||||
|
var.is_inferred = True
|
||||||
|
info.names[name] = SymbolTableNode(MDEF, var,
|
||||||
|
plugin_generated=True)
|
||||||
@@ -1,41 +1,33 @@
|
|||||||
from collections import OrderedDict
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
|
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from django.db.models.fields import Field
|
|
||||||
from django.db.models.fields.related import RelatedField
|
from django.db.models.fields.related import RelatedField
|
||||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||||
from mypy import checker
|
|
||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.mro import calculate_mro
|
from mypy.mro import calculate_mro
|
||||||
from mypy.nodes import (
|
from mypy.nodes import (
|
||||||
GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, PlaceholderNode,
|
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
|
||||||
StrExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var,
|
SymbolTable, SymbolTableNode, TypeInfo, Var,
|
||||||
)
|
)
|
||||||
from mypy.plugin import (
|
|
||||||
AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext,
|
|
||||||
)
|
|
||||||
from mypy.plugins.common import add_method
|
|
||||||
from mypy.semanal import SemanticAnalyzer
|
from mypy.semanal import SemanticAnalyzer
|
||||||
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
|
from mypy.types import AnyType, Instance, NoneTyp
|
||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
from mypy.types import TypedDictType, TypeOfAny, UnionType
|
from mypy.types import TypeOfAny, UnionType
|
||||||
|
|
||||||
|
from django.db.models.fields import Field
|
||||||
from mypy_django_plugin.lib import fullnames
|
from mypy_django_plugin.lib import fullnames
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
|
|
||||||
|
AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer]
|
||||||
|
|
||||||
|
|
||||||
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
|
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
|
||||||
return model_info.metadata.setdefault('django', {})
|
return model_info.metadata.setdefault('django', {})
|
||||||
|
|
||||||
|
|
||||||
class IncompleteDefnException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
|
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
|
||||||
if '.' not in fullname:
|
if '.' not in fullname:
|
||||||
return None
|
return None
|
||||||
@@ -57,14 +49,14 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
|
|||||||
return sym.node
|
return sym.node
|
||||||
|
|
||||||
|
|
||||||
def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]:
|
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
|
||||||
node = lookup_fully_qualified_generic(fullname, api.modules)
|
node = lookup_fully_qualified_generic(fullname, api.modules)
|
||||||
if not isinstance(node, TypeInfo):
|
if not isinstance(node, TypeInfo):
|
||||||
return None
|
return None
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]:
|
def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
|
||||||
fullname = get_class_fullname(klass)
|
fullname = get_class_fullname(klass)
|
||||||
field_info = lookup_fully_qualified_typeinfo(api, fullname)
|
field_info = lookup_fully_qualified_typeinfo(api, fullname)
|
||||||
return field_info
|
return field_info
|
||||||
@@ -79,36 +71,6 @@ def get_class_fullname(klass: type) -> str:
|
|||||||
return klass.__module__ + '.' + klass.__qualname__
|
return klass.__module__ + '.' + klass.__qualname__
|
||||||
|
|
||||||
|
|
||||||
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
|
|
||||||
"""
|
|
||||||
Return the expression for the specific argument.
|
|
||||||
This helper should only be used with non-star arguments.
|
|
||||||
"""
|
|
||||||
if name not in ctx.callee_arg_names:
|
|
||||||
return None
|
|
||||||
idx = ctx.callee_arg_names.index(name)
|
|
||||||
args = ctx.args[idx]
|
|
||||||
if len(args) != 1:
|
|
||||||
# Either an error or no value passed.
|
|
||||||
return None
|
|
||||||
return args[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
|
|
||||||
"""Return the type for the specific argument.
|
|
||||||
|
|
||||||
This helper should only be used with non-star arguments.
|
|
||||||
"""
|
|
||||||
if name not in ctx.callee_arg_names:
|
|
||||||
return None
|
|
||||||
idx = ctx.callee_arg_names.index(name)
|
|
||||||
arg_types = ctx.arg_types[idx]
|
|
||||||
if len(arg_types) != 1:
|
|
||||||
# Either an error or no value passed.
|
|
||||||
return None
|
|
||||||
return arg_types[0]
|
|
||||||
|
|
||||||
|
|
||||||
def make_optional(typ: MypyType) -> MypyType:
|
def make_optional(typ: MypyType) -> MypyType:
|
||||||
return UnionType.make_union([typ, NoneTyp()])
|
return UnionType.make_union([typ, NoneTyp()])
|
||||||
|
|
||||||
@@ -153,7 +115,7 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
|
|||||||
return AnyType(TypeOfAny.explicit)
|
return AnyType(TypeOfAny.explicit)
|
||||||
|
|
||||||
|
|
||||||
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
|
def get_field_lookup_exact_type(api: AnyPluginAPI, field: Field) -> MypyType:
|
||||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||||
lookup_type_class = field.related_model
|
lookup_type_class = field.related_model
|
||||||
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
|
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
|
||||||
@@ -168,44 +130,10 @@ def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
|
|||||||
is_nullable=field.null)
|
is_nullable=field.null)
|
||||||
|
|
||||||
|
|
||||||
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
|
def get_current_module(api: AnyPluginAPI) -> MypyFile:
|
||||||
metaclass_sym = info.names.get('Meta')
|
if isinstance(api, SemanticAnalyzer):
|
||||||
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
|
return api.cur_mod_node
|
||||||
return metaclass_sym.node
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def add_new_class_for_module(module: MypyFile,
|
|
||||||
name: str,
|
|
||||||
bases: List[Instance],
|
|
||||||
fields: Optional[Dict[str, MypyType]] = None
|
|
||||||
) -> TypeInfo:
|
|
||||||
new_class_unique_name = checker.gen_unique_name(name, module.names)
|
|
||||||
|
|
||||||
# make new class expression
|
|
||||||
classdef = ClassDef(new_class_unique_name, Block([]))
|
|
||||||
classdef.fullname = module.fullname + '.' + new_class_unique_name
|
|
||||||
|
|
||||||
# make new TypeInfo
|
|
||||||
new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname)
|
|
||||||
new_typeinfo.bases = bases
|
|
||||||
calculate_mro(new_typeinfo)
|
|
||||||
new_typeinfo.calculate_metaclass_type()
|
|
||||||
|
|
||||||
# add fields
|
|
||||||
if fields:
|
|
||||||
for field_name, field_type in fields.items():
|
|
||||||
var = Var(field_name, type=field_type)
|
|
||||||
var.info = new_typeinfo
|
|
||||||
var._fullname = new_typeinfo.fullname + '.' + field_name
|
|
||||||
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
|
||||||
|
|
||||||
classdef.info = new_typeinfo
|
|
||||||
module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
|
|
||||||
return new_typeinfo
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_module(api: TypeChecker) -> MypyFile:
|
|
||||||
current_module = None
|
current_module = None
|
||||||
for item in reversed(api.scope.stack):
|
for item in reversed(api.scope.stack):
|
||||||
if isinstance(item, MypyFile):
|
if isinstance(item, MypyFile):
|
||||||
@@ -215,21 +143,6 @@ def get_current_module(api: TypeChecker) -> MypyFile:
|
|||||||
return current_module
|
return current_module
|
||||||
|
|
||||||
|
|
||||||
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
|
|
||||||
current_module = get_current_module(api)
|
|
||||||
namedtuple_info = add_new_class_for_module(current_module, name,
|
|
||||||
bases=[api.named_generic_type('typing.NamedTuple', [])],
|
|
||||||
fields=fields)
|
|
||||||
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
|
|
||||||
|
|
||||||
|
|
||||||
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
|
|
||||||
# fallback for tuples is any builtins.tuple instance
|
|
||||||
fallback = api.named_generic_type('builtins.tuple',
|
|
||||||
[AnyType(TypeOfAny.special_form)])
|
|
||||||
return TupleType(fields, fallback=fallback)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||||
if isinstance(typ, UnionType):
|
if isinstance(typ, UnionType):
|
||||||
converted_items = []
|
converted_items = []
|
||||||
@@ -252,13 +165,6 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
|||||||
return typ
|
return typ
|
||||||
|
|
||||||
|
|
||||||
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
|
|
||||||
required_keys: Set[str]) -> TypedDictType:
|
|
||||||
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
|
|
||||||
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
|
|
||||||
return typed_dict_type
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
|
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
|
||||||
if isinstance(attr_expr, StrExpr):
|
if isinstance(attr_expr, StrExpr):
|
||||||
return attr_expr.value
|
return attr_expr.value
|
||||||
@@ -272,104 +178,25 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: 'Djang
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
|
def is_subclass_of_model(info: TypeInfo, django_context: 'DjangoContext') -> bool:
|
||||||
if not isinstance(ctx.api, SemanticAnalyzer):
|
|
||||||
raise ValueError('Not a SemanticAnalyzer')
|
|
||||||
return ctx.api
|
|
||||||
|
|
||||||
|
|
||||||
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
|
|
||||||
if not isinstance(ctx.api, TypeChecker):
|
|
||||||
raise ValueError('Not a TypeChecker')
|
|
||||||
return ctx.api
|
|
||||||
|
|
||||||
|
|
||||||
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
|
|
||||||
return (info.fullname in django_context.all_registered_model_class_fullnames
|
return (info.fullname in django_context.all_registered_model_class_fullnames
|
||||||
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
||||||
|
|
||||||
|
|
||||||
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
|
def new_typeinfo(name: str,
|
||||||
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
|
*,
|
||||||
api = get_typechecker_api(ctx)
|
bases: List[Instance],
|
||||||
api.check_subtype(actual_type, expected_type,
|
module_name: str) -> TypeInfo:
|
||||||
ctx.context, error_message,
|
"""
|
||||||
'got', 'expected')
|
Construct new TypeInfo instance. Cannot be used for nested classes.
|
||||||
|
"""
|
||||||
|
class_def = ClassDef(name, Block([]))
|
||||||
|
class_def.fullname = module_name + '.' + name
|
||||||
|
|
||||||
|
info = TypeInfo(SymbolTable(), class_def, module_name)
|
||||||
|
info.bases = bases
|
||||||
|
calculate_mro(info)
|
||||||
|
info.metaclass_type = info.calculate_metaclass_type()
|
||||||
|
|
||||||
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
|
class_def.info = info
|
||||||
# type=: type of the variable itself
|
return info
|
||||||
var = Var(name=name, type=sym_type)
|
|
||||||
# var.info: type of the object variable is bound to
|
|
||||||
var.info = info
|
|
||||||
var._fullname = info.fullname + '.' + name
|
|
||||||
var.is_initialized_in_class = True
|
|
||||||
var.is_inferred = True
|
|
||||||
info.names[name] = SymbolTableNode(MDEF, var,
|
|
||||||
plugin_generated=True)
|
|
||||||
|
|
||||||
|
|
||||||
def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
|
|
||||||
prepared_arguments = []
|
|
||||||
for argument in method_node.arguments[1:]:
|
|
||||||
argument.type_annotation = AnyType(TypeOfAny.unannotated)
|
|
||||||
prepared_arguments.append(argument)
|
|
||||||
return_type = AnyType(TypeOfAny.unannotated)
|
|
||||||
return prepared_arguments, return_type
|
|
||||||
|
|
||||||
|
|
||||||
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
|
|
||||||
new_method_name: str, method_node: FuncDef) -> None:
|
|
||||||
semanal_api = get_semanal_api(ctx)
|
|
||||||
if method_node.type is None:
|
|
||||||
if not semanal_api.final_iteration:
|
|
||||||
semanal_api.defer()
|
|
||||||
return
|
|
||||||
|
|
||||||
arguments, return_type = build_unannotated_method_args(method_node)
|
|
||||||
add_method(ctx,
|
|
||||||
new_method_name,
|
|
||||||
args=arguments,
|
|
||||||
return_type=return_type,
|
|
||||||
self_type=self_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
method_type = method_node.type
|
|
||||||
if not isinstance(method_type, CallableType):
|
|
||||||
if not semanal_api.final_iteration:
|
|
||||||
semanal_api.defer()
|
|
||||||
return
|
|
||||||
|
|
||||||
arguments = []
|
|
||||||
bound_return_type = semanal_api.anal_type(method_type.ret_type,
|
|
||||||
allow_placeholder=True)
|
|
||||||
assert bound_return_type is not None
|
|
||||||
|
|
||||||
if isinstance(bound_return_type, PlaceholderNode):
|
|
||||||
return
|
|
||||||
|
|
||||||
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
|
|
||||||
method_type.arg_types[1:],
|
|
||||||
method_node.arguments[1:]):
|
|
||||||
bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True)
|
|
||||||
assert bound_arg_type is not None
|
|
||||||
|
|
||||||
if isinstance(bound_arg_type, PlaceholderNode):
|
|
||||||
return
|
|
||||||
|
|
||||||
var = Var(name=original_argument.variable.name,
|
|
||||||
type=arg_type)
|
|
||||||
var.line = original_argument.variable.line
|
|
||||||
var.column = original_argument.variable.column
|
|
||||||
argument = Argument(variable=var,
|
|
||||||
type_annotation=bound_arg_type,
|
|
||||||
initializer=original_argument.initializer,
|
|
||||||
kind=original_argument.kind)
|
|
||||||
argument.set_line(original_argument)
|
|
||||||
arguments.append(argument)
|
|
||||||
|
|
||||||
add_method(ctx,
|
|
||||||
new_method_name,
|
|
||||||
args=arguments,
|
|
||||||
return_type=bound_return_type,
|
|
||||||
self_type=self_type)
|
|
||||||
|
|||||||
117
mypy_django_plugin/lib/sem_helpers.py
Normal file
117
mypy_django_plugin/lib/sem_helpers.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from typing import Union, Tuple, List, Optional, NamedTuple, cast
|
||||||
|
|
||||||
|
from mypy.nodes import Argument, FuncDef, Var, TypeInfo
|
||||||
|
from mypy.plugin import DynamicClassDefContext, ClassDefContext
|
||||||
|
from mypy.plugins.common import add_method
|
||||||
|
from mypy.semanal import SemanticAnalyzer
|
||||||
|
from mypy.types import Instance, CallableType, AnyType, TypeOfAny, PlaceholderType
|
||||||
|
from mypy.types import Type as MypyType
|
||||||
|
|
||||||
|
|
||||||
|
class IncompleteDefnException(Exception):
|
||||||
|
def __init__(self, error_message: str = '') -> None:
|
||||||
|
super().__init__(error_message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoundNameNotFound(IncompleteDefnException):
|
||||||
|
def __init__(self, fullname: str) -> None:
|
||||||
|
super().__init__(f'No {fullname!r} found')
|
||||||
|
|
||||||
|
|
||||||
|
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
|
||||||
|
return cast(SemanticAnalyzer, ctx.api)
|
||||||
|
|
||||||
|
|
||||||
|
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
|
||||||
|
metaclass_sym = info.names.get('Meta')
|
||||||
|
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
|
||||||
|
return metaclass_sym.node
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_unannotated_method_signature(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
|
||||||
|
prepared_arguments = []
|
||||||
|
for argument in method_node.arguments[1:]:
|
||||||
|
argument.type_annotation = AnyType(TypeOfAny.unannotated)
|
||||||
|
prepared_arguments.append(argument)
|
||||||
|
return_type = AnyType(TypeOfAny.unannotated)
|
||||||
|
return prepared_arguments, return_type
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureTuple(NamedTuple):
|
||||||
|
arguments: Optional[List[Argument]]
|
||||||
|
return_type: Optional[MypyType]
|
||||||
|
cannot_be_bound: bool
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> SignatureTuple:
|
||||||
|
method_type = method_node.type
|
||||||
|
assert isinstance(method_type, CallableType)
|
||||||
|
|
||||||
|
arguments = []
|
||||||
|
unbound = False
|
||||||
|
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
|
||||||
|
method_type.arg_types[1:],
|
||||||
|
method_node.arguments[1:]):
|
||||||
|
arg_type = api.anal_type(arg_type, allow_placeholder=True)
|
||||||
|
if isinstance(arg_type, PlaceholderType):
|
||||||
|
unbound = True
|
||||||
|
|
||||||
|
var = Var(name=original_argument.variable.name,
|
||||||
|
type=arg_type)
|
||||||
|
var.set_line(original_argument.variable)
|
||||||
|
|
||||||
|
if isinstance(arg_type, PlaceholderType):
|
||||||
|
unbound = True
|
||||||
|
argument = Argument(variable=var,
|
||||||
|
type_annotation=arg_type,
|
||||||
|
initializer=original_argument.initializer,
|
||||||
|
kind=original_argument.kind)
|
||||||
|
argument.set_line(original_argument)
|
||||||
|
arguments.append(argument)
|
||||||
|
|
||||||
|
ret_type = api.anal_type(method_type.ret_type, allow_placeholder=True)
|
||||||
|
if isinstance(ret_type, PlaceholderType):
|
||||||
|
unbound = True
|
||||||
|
return SignatureTuple(arguments, ret_type, unbound)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_method_or_incomplete_defn_exception(ctx: ClassDefContext,
|
||||||
|
self_type: Instance,
|
||||||
|
new_method_name: str,
|
||||||
|
method_node: FuncDef) -> None:
|
||||||
|
semanal_api = get_semanal_api(ctx)
|
||||||
|
|
||||||
|
if method_node.type is None:
|
||||||
|
if not semanal_api.final_iteration:
|
||||||
|
raise IncompleteDefnException(f'Unannotated method {method_node.fullname!r}')
|
||||||
|
|
||||||
|
arguments, return_type = prepare_unannotated_method_signature(method_node)
|
||||||
|
add_method(ctx,
|
||||||
|
new_method_name,
|
||||||
|
args=arguments,
|
||||||
|
return_type=return_type,
|
||||||
|
self_type=self_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
assert isinstance(method_node.type, CallableType)
|
||||||
|
|
||||||
|
# copy global SymbolTableNode objects from original class to the current node, if not present
|
||||||
|
original_module = semanal_api.modules[method_node.info.module_name]
|
||||||
|
for name, sym in original_module.names.items():
|
||||||
|
if (not sym.plugin_generated
|
||||||
|
and name not in semanal_api.cur_mod_node.names):
|
||||||
|
semanal_api.add_imported_symbol(name, sym, context=semanal_api.cur_mod_node)
|
||||||
|
|
||||||
|
arguments, return_type, unbound = analyze_callable_signature(semanal_api, method_node)
|
||||||
|
assert len(arguments) + 1 == len(method_node.arguments)
|
||||||
|
if unbound:
|
||||||
|
raise IncompleteDefnException(f'Signature of method {method_node.fullname!r} is not ready')
|
||||||
|
|
||||||
|
if new_method_name in ctx.cls.info.names:
|
||||||
|
del ctx.cls.info.names[new_method_name]
|
||||||
|
add_method(ctx,
|
||||||
|
new_method_name,
|
||||||
|
args=arguments,
|
||||||
|
return_type=return_type,
|
||||||
|
self_type=self_type)
|
||||||
@@ -9,6 +9,7 @@ from mypy.options import Options
|
|||||||
from mypy.plugin import (
|
from mypy.plugin import (
|
||||||
AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin,
|
AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin,
|
||||||
)
|
)
|
||||||
|
from mypy.semanal import dummy_context
|
||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
|
|
||||||
import mypy_django_plugin.transformers.orm_lookups
|
import mypy_django_plugin.transformers.orm_lookups
|
||||||
@@ -30,7 +31,7 @@ def transform_model_class(ctx: ClassDefContext,
|
|||||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||||
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
|
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
|
||||||
else:
|
else:
|
||||||
if not ctx.api.final_iteration:
|
if not ctx.api.final_iteration and not ctx.api.deferred:
|
||||||
ctx.api.defer()
|
ctx.api.defer()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -180,7 +181,7 @@ class NewSemanalDjangoPlugin(Plugin):
|
|||||||
if info.has_base(fullnames.FIELD_FULLNAME):
|
if info.has_base(fullnames.FIELD_FULLNAME):
|
||||||
return partial(fields.transform_into_proper_return_type, django_context=self.django_context)
|
return partial(fields.transform_into_proper_return_type, django_context=self.django_context)
|
||||||
|
|
||||||
if helpers.is_model_subclass_info(info, self.django_context):
|
if helpers.is_subclass_of_model(info, self.django_context):
|
||||||
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
|
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ from mypy.types import Type as MypyType
|
|||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
|
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
|
||||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
|
||||||
if (outer_model_info is None
|
if (outer_model_info is None
|
||||||
or not helpers.is_model_subclass_info(outer_model_info, django_context)):
|
or not helpers.is_subclass_of_model(outer_model_info, django_context)):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
field_name = None
|
field_name = None
|
||||||
@@ -66,12 +66,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
|||||||
# __get__/__set__ of ForeignKey of derived model
|
# __get__/__set__ of ForeignKey of derived model
|
||||||
for model_cls in django_context.all_registered_model_classes:
|
for model_cls in django_context.all_registered_model_classes:
|
||||||
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
|
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
|
||||||
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
|
derived_model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls)
|
||||||
if derived_model_info is not None:
|
if derived_model_info is not None:
|
||||||
fk_ref_type = Instance(derived_model_info, [])
|
fk_ref_type = Instance(derived_model_info, [])
|
||||||
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
|
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
|
||||||
set_type=fk_ref_type, get_type=fk_ref_type)
|
set_type=fk_ref_type, get_type=fk_ref_type)
|
||||||
helpers.add_new_sym_for_info(derived_model_info,
|
chk_helpers.add_new_sym_for_info(derived_model_info,
|
||||||
name=current_field.name,
|
name=current_field.name,
|
||||||
sym_type=derived_fk_type)
|
sym_type=derived_fk_type)
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
|||||||
if related_model_to_set._meta.proxy_for_model is not None:
|
if related_model_to_set._meta.proxy_for_model is not None:
|
||||||
related_model_to_set = related_model_to_set._meta.proxy_for_model
|
related_model_to_set = related_model_to_set._meta.proxy_for_model
|
||||||
|
|
||||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
typechecker_api = chk_helpers.get_typechecker_api(ctx)
|
||||||
|
|
||||||
related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model)
|
related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model)
|
||||||
if related_model_info is None:
|
if related_model_info is None:
|
||||||
@@ -114,7 +114,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
|||||||
default_return_type = cast(Instance, ctx.default_return_type)
|
default_return_type = cast(Instance, ctx.default_return_type)
|
||||||
|
|
||||||
is_nullable = False
|
is_nullable = False
|
||||||
null_expr = helpers.get_call_argument_by_name(ctx, 'null')
|
null_expr = chk_helpers.get_call_argument_by_name(ctx, 'null')
|
||||||
if null_expr is not None:
|
if null_expr is not None:
|
||||||
is_nullable = helpers.parse_bool(null_expr) or False
|
is_nullable = helpers.parse_bool(null_expr) or False
|
||||||
|
|
||||||
@@ -122,10 +122,10 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
|||||||
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
|
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
|
||||||
|
|
||||||
|
|
||||||
def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
def determine_type_of_array_field(ctx: FunctionContext) -> MypyType:
|
||||||
default_return_type = set_descriptor_types_for_field(ctx)
|
default_return_type = set_descriptor_types_for_field(ctx)
|
||||||
|
|
||||||
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field')
|
base_field_arg_type = chk_helpers.get_call_argument_type_by_name(ctx, 'base_field')
|
||||||
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
|
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
|
||||||
return default_return_type
|
return default_return_type
|
||||||
|
|
||||||
@@ -141,9 +141,9 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
|
|||||||
default_return_type = ctx.default_return_type
|
default_return_type = ctx.default_return_type
|
||||||
assert isinstance(default_return_type, Instance)
|
assert isinstance(default_return_type, Instance)
|
||||||
|
|
||||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
|
||||||
if (outer_model_info is None
|
if (outer_model_info is None
|
||||||
or not helpers.is_model_subclass_info(outer_model_info, django_context)):
|
or not helpers.is_subclass_of_model(outer_model_info, django_context)):
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
assert isinstance(outer_model_info, TypeInfo)
|
assert isinstance(outer_model_info, TypeInfo)
|
||||||
@@ -152,6 +152,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
|
|||||||
return fill_descriptor_types_for_related_field(ctx, django_context)
|
return fill_descriptor_types_for_related_field(ctx, django_context)
|
||||||
|
|
||||||
if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
|
if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
|
||||||
return determine_type_of_array_field(ctx, django_context)
|
return determine_type_of_array_field(ctx)
|
||||||
|
|
||||||
return set_descriptor_types_for_field(ctx)
|
return set_descriptor_types_for_field(ctx)
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ from mypy.types import CallableType, Instance, NoneTyp
|
|||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
from mypy.types import TypeType
|
from mypy.types import TypeType
|
||||||
|
|
||||||
from mypy_django_plugin.lib import helpers
|
from mypy_django_plugin.lib import sem_helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
|
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
|
||||||
meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info)
|
meta_node = sem_helpers.get_nested_meta_node_for_current_class(ctx.cls.info)
|
||||||
if meta_node is None:
|
if meta_node is None:
|
||||||
if not ctx.api.final_iteration:
|
if not ctx.api.final_iteration:
|
||||||
ctx.api.defer()
|
ctx.api.defer()
|
||||||
@@ -28,7 +28,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType:
|
|||||||
object_type = ctx.type
|
object_type = ctx.type
|
||||||
assert isinstance(object_type, Instance)
|
assert isinstance(object_type, Instance)
|
||||||
|
|
||||||
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class')
|
form_class_type = chk_helpers.get_call_argument_type_by_name(ctx, 'form_class')
|
||||||
if form_class_type is None or isinstance(form_class_type, NoneTyp):
|
if form_class_type is None or isinstance(form_class_type, NoneTyp):
|
||||||
form_class_type = get_specified_form_class(object_type)
|
form_class_type = get_specified_form_class(object_type)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from mypy.types import Instance
|
|||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import helpers
|
from mypy_django_plugin.lib import chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||||
@@ -32,7 +32,7 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
|||||||
|
|
||||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||||
model_cls: Type[Model], method: str) -> MypyType:
|
model_cls: Type[Model], method: str) -> MypyType:
|
||||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
typechecker_api = chk_helpers.get_typechecker_api(ctx)
|
||||||
expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method)
|
expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method)
|
||||||
expected_keys = [key for key in expected_types.keys() if key != 'pk']
|
expected_keys = [key for key in expected_types.keys() if key != 'pk']
|
||||||
|
|
||||||
@@ -42,11 +42,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
|
|||||||
model_cls.__name__),
|
model_cls.__name__),
|
||||||
ctx.context)
|
ctx.context)
|
||||||
continue
|
continue
|
||||||
helpers.check_types_compatible(ctx,
|
error_message = 'Incompatible type for "{}" of "{}"'.format(actual_name, model_cls.__name__)
|
||||||
|
chk_helpers.check_types_compatible(ctx,
|
||||||
expected_type=expected_types[actual_name],
|
expected_type=expected_types[actual_name],
|
||||||
actual_type=actual_type,
|
actual_type=actual_type,
|
||||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
error_message=error_message)
|
||||||
model_cls.__name__))
|
|
||||||
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|||||||
@@ -1,77 +1,149 @@
|
|||||||
|
from typing import Iterator, Tuple, Optional
|
||||||
|
|
||||||
from mypy.nodes import (
|
from mypy.nodes import (
|
||||||
GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo,
|
FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo,
|
||||||
|
PlaceholderNode, SymbolTableNode, GDEF
|
||||||
)
|
)
|
||||||
from mypy.plugin import ClassDefContext, DynamicClassDefContext
|
from mypy.plugin import ClassDefContext, DynamicClassDefContext
|
||||||
from mypy.types import AnyType, Instance, TypeOfAny
|
from mypy.types import AnyType, Instance, TypeOfAny
|
||||||
|
from mypy.typevars import fill_typevars
|
||||||
|
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, sem_helpers, helpers
|
||||||
|
|
||||||
|
|
||||||
def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None:
|
def iter_all_custom_queryset_methods(derived_queryset_info: TypeInfo) -> Iterator[Tuple[str, FuncDef]]:
|
||||||
semanal_api = helpers.get_semanal_api(ctx)
|
for base_queryset_info in derived_queryset_info.mro:
|
||||||
|
if base_queryset_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
|
||||||
|
break
|
||||||
|
for name, sym in base_queryset_info.names.items():
|
||||||
|
if isinstance(sym.node, FuncDef):
|
||||||
|
yield name, sym.node
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_callee_manager_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
|
||||||
callee = ctx.call.callee
|
callee = ctx.call.callee
|
||||||
assert isinstance(callee, MemberExpr)
|
assert isinstance(callee, MemberExpr)
|
||||||
assert isinstance(callee.expr, RefExpr)
|
assert isinstance(callee.expr, RefExpr)
|
||||||
|
|
||||||
base_manager_info = callee.expr.node
|
callee_manager_info = callee.expr.node
|
||||||
if base_manager_info is None:
|
if (callee_manager_info is None
|
||||||
if not semanal_api.final_iteration:
|
or isinstance(callee_manager_info, PlaceholderNode)):
|
||||||
semanal_api.defer()
|
raise sem_helpers.IncompleteDefnException(f'Definition of base manager {callee_manager_info.fullname} '
|
||||||
return
|
f'is incomplete.')
|
||||||
|
|
||||||
assert isinstance(base_manager_info, TypeInfo)
|
assert isinstance(callee_manager_info, TypeInfo)
|
||||||
new_manager_info = semanal_api.basic_new_typeinfo(ctx.name,
|
return callee_manager_info
|
||||||
basetype_or_fallback=Instance(base_manager_info,
|
|
||||||
[AnyType(TypeOfAny.unannotated)]))
|
|
||||||
new_manager_info.line = ctx.call.line
|
|
||||||
new_manager_info.defn.line = ctx.call.line
|
|
||||||
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
|
|
||||||
|
|
||||||
current_module = semanal_api.cur_mod_node
|
|
||||||
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info,
|
|
||||||
plugin_generated=True)
|
|
||||||
passed_queryset = ctx.call.args[0]
|
|
||||||
assert isinstance(passed_queryset, NameExpr)
|
|
||||||
|
|
||||||
derived_queryset_fullname = passed_queryset.fullname
|
def resolve_passed_queryset_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
|
||||||
assert derived_queryset_fullname is not None
|
api = sem_helpers.get_semanal_api(ctx)
|
||||||
|
|
||||||
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
|
passed_queryset_name_expr = ctx.call.args[0]
|
||||||
assert sym is not None
|
assert isinstance(passed_queryset_name_expr, NameExpr)
|
||||||
if sym.node is None:
|
|
||||||
if not semanal_api.final_iteration:
|
|
||||||
semanal_api.defer()
|
|
||||||
else:
|
|
||||||
# inherit from Any to prevent false-positives, if queryset class cannot be resolved
|
|
||||||
new_manager_info.fallback_to_any = True
|
|
||||||
return
|
|
||||||
|
|
||||||
derived_queryset_info = sym.node
|
sym = api.lookup_qualified(passed_queryset_name_expr.name, ctx=ctx.call)
|
||||||
assert isinstance(derived_queryset_info, TypeInfo)
|
if (sym is None
|
||||||
|
or sym.node is None
|
||||||
|
or isinstance(sym.node, PlaceholderNode)):
|
||||||
|
raise sem_helpers.BoundNameNotFound(passed_queryset_name_expr.fullname)
|
||||||
|
|
||||||
|
assert isinstance(sym.node, TypeInfo)
|
||||||
|
return sym.node
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> Optional[TypeInfo]:
|
||||||
|
api = sem_helpers.get_semanal_api(ctx)
|
||||||
|
|
||||||
|
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
|
||||||
|
if (sym is None
|
||||||
|
or sym.node is None
|
||||||
|
or isinstance(sym.node, PlaceholderNode)):
|
||||||
|
raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME)
|
||||||
|
|
||||||
|
assert isinstance(sym.node, TypeInfo)
|
||||||
|
return sym.node
|
||||||
|
|
||||||
|
|
||||||
|
def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo:
|
||||||
|
callee_manager_type = Instance(callee_manager_info, [AnyType(TypeOfAny.unannotated)])
|
||||||
|
api = sem_helpers.get_semanal_api(ctx)
|
||||||
|
|
||||||
|
new_manager_class_name = ctx.name
|
||||||
|
new_manager_info = helpers.new_typeinfo(new_manager_class_name,
|
||||||
|
bases=[callee_manager_type], module_name=api.cur_mod_id)
|
||||||
|
new_manager_info.set_line(ctx.call)
|
||||||
|
return new_manager_info
|
||||||
|
|
||||||
|
|
||||||
|
def record_new_manager_info_fullname_into_metadata(ctx: DynamicClassDefContext,
|
||||||
|
new_manager_fullname: str,
|
||||||
|
callee_manager_info: TypeInfo,
|
||||||
|
queryset_info: TypeInfo,
|
||||||
|
django_manager_info: TypeInfo) -> None:
|
||||||
if len(ctx.call.args) > 1:
|
if len(ctx.call.args) > 1:
|
||||||
expr = ctx.call.args[1]
|
expr = ctx.call.args[1]
|
||||||
assert isinstance(expr, StrExpr)
|
assert isinstance(expr, StrExpr)
|
||||||
custom_manager_generated_name = expr.value
|
custom_manager_generated_name = expr.value
|
||||||
else:
|
else:
|
||||||
custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name
|
custom_manager_generated_name = callee_manager_info.name + 'From' + queryset_info.name
|
||||||
|
|
||||||
custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name])
|
custom_manager_generated_fullname = 'django.db.models.manager' + '.' + custom_manager_generated_name
|
||||||
if 'from_queryset_managers' not in base_manager_info.metadata:
|
|
||||||
base_manager_info.metadata['from_queryset_managers'] = {}
|
metadata = django_manager_info.metadata.setdefault('from_queryset_managers', {})
|
||||||
base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname
|
metadata[custom_manager_generated_fullname] = new_manager_fullname
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None:
|
||||||
|
semanal_api = sem_helpers.get_semanal_api(ctx)
|
||||||
|
try:
|
||||||
|
callee_manager_info = resolve_callee_manager_info_or_exception(ctx)
|
||||||
|
queryset_info = resolve_passed_queryset_info_or_exception(ctx)
|
||||||
|
django_manager_info = resolve_django_manager_info_or_exception(ctx)
|
||||||
|
except sem_helpers.IncompleteDefnException:
|
||||||
|
if not semanal_api.final_iteration:
|
||||||
|
semanal_api.defer()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
new_manager_info = new_manager_typeinfo(ctx, callee_manager_info)
|
||||||
|
record_new_manager_info_fullname_into_metadata(ctx,
|
||||||
|
new_manager_info.fullname,
|
||||||
|
callee_manager_info,
|
||||||
|
queryset_info,
|
||||||
|
django_manager_info)
|
||||||
|
|
||||||
class_def_context = ClassDefContext(cls=new_manager_info.defn,
|
class_def_context = ClassDefContext(cls=new_manager_info.defn,
|
||||||
reason=ctx.call, api=semanal_api)
|
reason=ctx.call, api=semanal_api)
|
||||||
self_type = Instance(new_manager_info, [])
|
self_type = fill_typevars(new_manager_info)
|
||||||
# we need to copy all methods in MRO before django.db.models.query.QuerySet
|
# self_type = Instance(new_manager_info, [])
|
||||||
for class_mro_info in derived_queryset_info.mro:
|
|
||||||
if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME:
|
try:
|
||||||
break
|
for name, method_node in iter_all_custom_queryset_methods(queryset_info):
|
||||||
for name, sym in class_mro_info.names.items():
|
sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context,
|
||||||
if isinstance(sym.node, FuncDef):
|
|
||||||
helpers.copy_method_to_another_class(class_def_context,
|
|
||||||
self_type,
|
self_type,
|
||||||
new_method_name=name,
|
new_method_name=name,
|
||||||
method_node=sym.node)
|
method_node=method_node)
|
||||||
|
except sem_helpers.IncompleteDefnException:
|
||||||
|
if not semanal_api.final_iteration:
|
||||||
|
semanal_api.defer()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
|
||||||
|
|
||||||
|
# context=None - forcibly replace old node
|
||||||
|
added = semanal_api.add_symbol_table_node(ctx.name, new_manager_sym, context=None)
|
||||||
|
if added:
|
||||||
|
# replace all references to the old manager Var everywhere
|
||||||
|
for _, module in semanal_api.modules.items():
|
||||||
|
if module.fullname != semanal_api.cur_mod_id:
|
||||||
|
for sym_name, sym in module.names.items():
|
||||||
|
if sym.fullname == new_manager_info.fullname:
|
||||||
|
module.names[sym_name] = new_manager_sym.copy()
|
||||||
|
|
||||||
|
# we need another iteration to process methods
|
||||||
|
if (not added
|
||||||
|
and not semanal_api.final_iteration):
|
||||||
|
semanal_api.defer()
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ from mypy.types import Type as MypyType
|
|||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import helpers
|
from mypy_django_plugin.lib import helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||||
field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx),
|
api = chk_helpers.get_typechecker_api(ctx)
|
||||||
field_fullname)
|
field_info = helpers.lookup_fully_qualified_typeinfo(api, field_fullname)
|
||||||
if field_info is None:
|
if field_info is None:
|
||||||
return AnyType(TypeOfAny.unannotated)
|
return AnyType(TypeOfAny.unannotated)
|
||||||
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
|
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
|
||||||
@@ -32,7 +32,7 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
|
|||||||
if model_cls is None:
|
if model_cls is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
field_name_expr = chk_helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||||
if field_name_expr is None:
|
if field_name_expr is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
from typing import Dict, List, Optional, Type, cast
|
from typing import List, Optional, Type, cast
|
||||||
|
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
from django.db.models.fields import DateField, DateTimeField
|
from django.db.models.fields.related import ForeignKey, OneToOneField
|
||||||
from django.db.models.fields.related import ForeignKey
|
|
||||||
from django.db.models.fields.reverse_related import (
|
from django.db.models.fields.reverse_related import (
|
||||||
ManyToManyRel, ManyToOneRel, OneToOneRel,
|
ManyToManyRel, ManyToOneRel, OneToOneRel,
|
||||||
)
|
)
|
||||||
from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var
|
from mypy.nodes import ARG_STAR2, Argument, FuncDef, TypeInfo, Var, SymbolTableNode, MDEF, GDEF
|
||||||
from mypy.plugin import ClassDefContext
|
from mypy.plugin import ClassDefContext
|
||||||
from mypy.plugins import common
|
from mypy.plugins import common
|
||||||
from mypy.semanal import SemanticAnalyzer
|
from mypy.semanal import SemanticAnalyzer, dummy_context
|
||||||
from mypy.types import AnyType, Instance
|
from mypy.types import AnyType, Instance
|
||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
|
from django.db.models.fields import DateField, DateTimeField
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers, sem_helpers
|
||||||
from mypy_django_plugin.transformers import fields
|
from mypy_django_plugin.transformers import fields
|
||||||
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
|
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ class ModelClassInitializer:
|
|||||||
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
|
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
|
||||||
info = self.lookup_typeinfo(fullname)
|
info = self.lookup_typeinfo(fullname)
|
||||||
if info is None:
|
if info is None:
|
||||||
raise helpers.IncompleteDefnException(f'No {fullname!r} found')
|
raise sem_helpers.IncompleteDefnException(f'No {fullname!r} found')
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
|
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
|
||||||
@@ -43,26 +43,74 @@ class ModelClassInitializer:
|
|||||||
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
|
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
|
||||||
return field_info
|
return field_info
|
||||||
|
|
||||||
def create_new_var(self, name: str, typ: MypyType) -> Var:
|
def model_class_has_attribute_defined(self, name: str, traverse_mro: bool = True) -> bool:
|
||||||
# type=: type of the variable itself
|
if not traverse_mro:
|
||||||
var = Var(name=name, type=typ)
|
sym = self.model_classdef.info.names.get(name)
|
||||||
# var.info: type of the object variable is bound to
|
else:
|
||||||
|
sym = self.model_classdef.info.get(name)
|
||||||
|
return sym is not None
|
||||||
|
|
||||||
|
def resolve_manager_fullname(self, manager_fullname: str) -> str:
|
||||||
|
base_manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
|
||||||
|
if (base_manager_info is None
|
||||||
|
or 'from_queryset_managers' not in base_manager_info.metadata):
|
||||||
|
return manager_fullname
|
||||||
|
|
||||||
|
metadata = base_manager_info.metadata['from_queryset_managers']
|
||||||
|
return metadata.get(manager_fullname, manager_fullname)
|
||||||
|
|
||||||
|
def add_new_node_to_model_class(self, name: str, typ: MypyType,
|
||||||
|
force_replace_existing: bool = False) -> None:
|
||||||
|
if not force_replace_existing and name in self.model_classdef.info.names:
|
||||||
|
raise ValueError(f'Member {name!r} already defined at model {self.model_classdef.info.fullname!r}.')
|
||||||
|
|
||||||
|
var = Var(name, type=typ)
|
||||||
|
# TypeInfo of the object variable is bound to
|
||||||
var.info = self.model_classdef.info
|
var.info = self.model_classdef.info
|
||||||
var._fullname = self.model_classdef.info.fullname + '.' + name
|
var._fullname = self.api.qualified_name(name)
|
||||||
var.is_initialized_in_class = True
|
var.is_initialized_in_class = True
|
||||||
var.is_inferred = True
|
|
||||||
return var
|
|
||||||
|
|
||||||
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None:
|
sym = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||||
helpers.add_new_sym_for_info(self.model_classdef.info,
|
context = dummy_context()
|
||||||
name=name,
|
if force_replace_existing:
|
||||||
sym_type=typ)
|
context = None
|
||||||
|
self.api.add_symbol_table_node(name, sym, context=context)
|
||||||
|
|
||||||
def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo:
|
def add_new_class_for_current_module(self, name: str, bases: List[Instance],
|
||||||
current_module = self.api.modules[self.model_classdef.info.module_name]
|
force_replace_existing: bool = False) -> Optional[TypeInfo]:
|
||||||
new_class_info = helpers.add_new_class_for_module(current_module,
|
current_module = self.api.cur_mod_node
|
||||||
name=name, bases=bases)
|
if not force_replace_existing and name in current_module:
|
||||||
return new_class_info
|
raise ValueError(f'Class {name!r} already defined for module {current_module.fullname!r}')
|
||||||
|
|
||||||
|
new_typeinfo = helpers.new_typeinfo(name,
|
||||||
|
bases=bases,
|
||||||
|
module_name=current_module.fullname)
|
||||||
|
# sym = SymbolTableNode(GDEF, new_typeinfo,
|
||||||
|
# plugin_generated=True)
|
||||||
|
# context = dummy_context()
|
||||||
|
# if force_replace_existing:
|
||||||
|
# context = None
|
||||||
|
|
||||||
|
if name in current_module.names:
|
||||||
|
del current_module.names[name]
|
||||||
|
current_module.names[name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
|
||||||
|
# current_module.defs.append(new_typeinfo.defn)
|
||||||
|
# self.api.cur_mod_node.
|
||||||
|
# self.api.leave_class()
|
||||||
|
# added = self.api.add_symbol_table_node(name, sym, context=context)
|
||||||
|
# self.api.enter_class(self.model_classdef.info)
|
||||||
|
#
|
||||||
|
# self.api.cur_mod_node.defs.append(new_typeinfo.defn)
|
||||||
|
|
||||||
|
# if not added and force_replace_existing:
|
||||||
|
# return None
|
||||||
|
return new_typeinfo
|
||||||
|
|
||||||
|
# current_module = self.api.modules[self.model_classdef.info.module_name]
|
||||||
|
# context =
|
||||||
|
# new_class_info = helpers.add_new_class_for_module(current_module,
|
||||||
|
# name=name, bases=bases)
|
||||||
|
# return new_class_info
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
|
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
|
||||||
@@ -88,39 +136,71 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
|
meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
|
||||||
if meta_node is None:
|
if meta_node is None:
|
||||||
return None
|
return None
|
||||||
meta_node.fallback_to_any = True
|
meta_node.fallback_to_any = True
|
||||||
|
|
||||||
|
|
||||||
class AddDefaultPrimaryKey(ModelClassInitializer):
|
class AddDefaultPrimaryKey(ModelClassInitializer):
|
||||||
|
"""
|
||||||
|
Adds default primary key to models which does not define their own.
|
||||||
|
```
|
||||||
|
class User(models.Model):
|
||||||
|
name = models.TextField()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||||
auto_field = model_cls._meta.auto_field
|
auto_field = model_cls._meta.auto_field
|
||||||
if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname):
|
if auto_field is None:
|
||||||
# autogenerated field
|
return
|
||||||
auto_field_fullname = helpers.get_class_fullname(auto_field.__class__)
|
|
||||||
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname)
|
primary_key_attrname = auto_field.attname
|
||||||
|
if self.model_class_has_attribute_defined(primary_key_attrname):
|
||||||
|
return
|
||||||
|
|
||||||
|
auto_field_class_fullname = helpers.get_class_fullname(auto_field.__class__)
|
||||||
|
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_class_fullname)
|
||||||
|
|
||||||
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False)
|
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False)
|
||||||
self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info,
|
self.add_new_node_to_model_class(primary_key_attrname, Instance(auto_field_info,
|
||||||
[set_type, get_type]))
|
[set_type, get_type]))
|
||||||
|
|
||||||
|
|
||||||
class AddRelatedModelsId(ModelClassInitializer):
|
class AddRelatedModelsId(ModelClassInitializer):
|
||||||
|
"""
|
||||||
|
Adds `FIELDNAME_id` attributes to models.
|
||||||
|
```
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
class Blog(models.Model):
|
||||||
|
user = models.ForeignKey(User)
|
||||||
|
```
|
||||||
|
|
||||||
|
`user_id` will be added to `Blog`.
|
||||||
|
"""
|
||||||
|
|
||||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||||
for field in model_cls._meta.get_fields():
|
for field in model_cls._meta.get_fields():
|
||||||
if isinstance(field, ForeignKey):
|
if not isinstance(field, (OneToOneField, ForeignKey)):
|
||||||
|
continue
|
||||||
|
related_id_attr_name = field.attname
|
||||||
|
if self.model_class_has_attribute_defined(related_id_attr_name):
|
||||||
|
continue
|
||||||
|
# if self.get_model_class_attr(related_id_attr_name) is not None:
|
||||||
|
# continue
|
||||||
|
|
||||||
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
||||||
if related_model_cls is None:
|
if related_model_cls is None:
|
||||||
error_context: Context = self.ctx.cls
|
error_context = self.ctx.cls
|
||||||
field_sym = self.ctx.cls.info.get(field.name)
|
field_sym = self.ctx.cls.info.get(field.name)
|
||||||
if field_sym is not None and field_sym.node is not None:
|
if field_sym is not None and field_sym.node is not None:
|
||||||
error_context = field_sym.node
|
error_context = field_sym.node
|
||||||
self.api.fail(f'Cannot find model {field.related_model!r} '
|
self.api.fail(f'Cannot find model {field.related_model!r} '
|
||||||
f'referenced in field {field.name!r} ',
|
f'referenced in field {field.name!r} ',
|
||||||
ctx=error_context)
|
ctx=error_context)
|
||||||
self.add_new_node_to_model_class(field.attname,
|
self.add_new_node_to_model_class(related_id_attr_name,
|
||||||
AnyType(TypeOfAny.explicit))
|
AnyType(TypeOfAny.explicit))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -130,7 +210,7 @@ class AddRelatedModelsId(ModelClassInitializer):
|
|||||||
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
||||||
try:
|
try:
|
||||||
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
|
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
|
||||||
except helpers.IncompleteDefnException as exc:
|
except sem_helpers.IncompleteDefnException as exc:
|
||||||
if not self.api.final_iteration:
|
if not self.api.final_iteration:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
@@ -138,7 +218,7 @@ class AddRelatedModelsId(ModelClassInitializer):
|
|||||||
|
|
||||||
is_nullable = self.django_context.get_field_nullability(field, None)
|
is_nullable = self.django_context.get_field_nullability(field, None)
|
||||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||||
self.add_new_node_to_model_class(field.attname,
|
self.add_new_node_to_model_class(related_id_attr_name,
|
||||||
Instance(field_info, [set_type, get_type]))
|
Instance(field_info, [set_type, get_type]))
|
||||||
|
|
||||||
|
|
||||||
@@ -152,25 +232,15 @@ class AddManagers(ModelClassInitializer):
|
|||||||
def is_any_parametrized_manager(self, typ: Instance) -> bool:
|
def is_any_parametrized_manager(self, typ: Instance) -> bool:
|
||||||
return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType)
|
return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType)
|
||||||
|
|
||||||
def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]:
|
|
||||||
base_manager_info = self.lookup_typeinfo(base_manager_fullname)
|
|
||||||
if (base_manager_info is None
|
|
||||||
or 'from_queryset_managers' not in base_manager_info.metadata):
|
|
||||||
return {}
|
|
||||||
return base_manager_info.metadata['from_queryset_managers']
|
|
||||||
|
|
||||||
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
|
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
|
||||||
bases = []
|
bases = []
|
||||||
for original_base in base_manager_info.bases:
|
for original_base in base_manager_info.bases:
|
||||||
if self.is_any_parametrized_manager(original_base):
|
if self.is_any_parametrized_manager(original_base):
|
||||||
if original_base.type is None:
|
|
||||||
raise helpers.IncompleteDefnException()
|
|
||||||
|
|
||||||
original_base = helpers.reparametrize_instance(original_base,
|
original_base = helpers.reparametrize_instance(original_base,
|
||||||
[Instance(self.model_classdef.info, [])])
|
[Instance(self.model_classdef.info, [])])
|
||||||
bases.append(original_base)
|
bases.append(original_base)
|
||||||
|
|
||||||
new_manager_info = self.add_new_class_for_current_module(name, bases)
|
new_manager_info = self.add_new_class_for_current_module(name, bases, force_replace_existing=True)
|
||||||
# copy fields to a new manager
|
# copy fields to a new manager
|
||||||
new_cls_def_context = ClassDefContext(cls=new_manager_info.defn,
|
new_cls_def_context = ClassDefContext(cls=new_manager_info.defn,
|
||||||
reason=self.ctx.reason,
|
reason=self.ctx.reason,
|
||||||
@@ -178,9 +248,12 @@ class AddManagers(ModelClassInitializer):
|
|||||||
custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])])
|
custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])])
|
||||||
|
|
||||||
for name, sym in base_manager_info.names.items():
|
for name, sym in base_manager_info.names.items():
|
||||||
|
if name in new_manager_info.names:
|
||||||
|
raise ValueError(f'Name {name!r} already exists on newly-created {new_manager_info.fullname!r} class.')
|
||||||
|
|
||||||
# replace self type with new class, if copying method
|
# replace self type with new class, if copying method
|
||||||
if isinstance(sym.node, FuncDef):
|
if isinstance(sym.node, FuncDef):
|
||||||
helpers.copy_method_to_another_class(new_cls_def_context,
|
sem_helpers.copy_method_or_incomplete_defn_exception(new_cls_def_context,
|
||||||
self_type=custom_manager_type,
|
self_type=custom_manager_type,
|
||||||
new_method_name=name,
|
new_method_name=name,
|
||||||
method_node=sym.node)
|
method_node=sym.node)
|
||||||
@@ -192,32 +265,36 @@ class AddManagers(ModelClassInitializer):
|
|||||||
new_var.info = new_manager_info
|
new_var.info = new_manager_info
|
||||||
new_var._fullname = new_manager_info.fullname + '.' + name
|
new_var._fullname = new_manager_info.fullname + '.' + name
|
||||||
new_sym.node = new_var
|
new_sym.node = new_var
|
||||||
|
|
||||||
new_manager_info.names[name] = new_sym
|
new_manager_info.names[name] = new_sym
|
||||||
|
|
||||||
return custom_manager_type
|
return custom_manager_type
|
||||||
|
|
||||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||||
for manager_name, manager in model_cls._meta.managers_map.items():
|
for manager_name, manager in model_cls._meta.managers_map.items():
|
||||||
manager_class_name = manager.__class__.__name__
|
if self.model_class_has_attribute_defined(manager_name, traverse_mro=False):
|
||||||
manager_fullname = helpers.get_class_fullname(manager.__class__)
|
sym = self.model_classdef.info.names.get(manager_name)
|
||||||
try:
|
assert sym is not None
|
||||||
|
|
||||||
|
if (sym.type is not None
|
||||||
|
and isinstance(sym.type, Instance)
|
||||||
|
and sym.type.type.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)
|
||||||
|
and not self.has_any_parametrized_manager_as_base(sym.type.type)):
|
||||||
|
# already defined and parametrized properly
|
||||||
|
continue
|
||||||
|
|
||||||
|
if getattr(manager, '_built_with_as_manager', False):
|
||||||
|
# as_manager is not supported yet
|
||||||
|
if not self.model_class_has_attribute_defined(manager_name, traverse_mro=True):
|
||||||
|
self.add_new_node_to_model_class(manager_name, AnyType(TypeOfAny.explicit))
|
||||||
|
continue
|
||||||
|
|
||||||
|
manager_fullname = self.resolve_manager_fullname(helpers.get_class_fullname(manager.__class__))
|
||||||
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
|
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
|
||||||
except helpers.IncompleteDefnException as exc:
|
manager_class_name = manager_fullname.rsplit('.', maxsplit=1)[1]
|
||||||
if not self.api.final_iteration:
|
|
||||||
raise exc
|
|
||||||
else:
|
|
||||||
base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0])
|
|
||||||
generated_managers = self.get_generated_manager_mappings(base_manager_fullname)
|
|
||||||
if manager_fullname not in generated_managers:
|
|
||||||
# not a generated manager, continue with the loop
|
|
||||||
continue
|
|
||||||
real_manager_fullname = generated_managers[manager_fullname]
|
|
||||||
manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore
|
|
||||||
if manager_info is None:
|
|
||||||
continue
|
|
||||||
manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1]
|
|
||||||
|
|
||||||
if manager_name not in self.model_classdef.info.names:
|
if manager_name not in self.model_classdef.info.names:
|
||||||
|
# manager not yet defined, just add models.Manager[ModelName]
|
||||||
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
|
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
|
||||||
self.add_new_node_to_model_class(manager_name, manager_type)
|
self.add_new_node_to_model_class(manager_name, manager_type)
|
||||||
else:
|
else:
|
||||||
@@ -226,21 +303,28 @@ class AddManagers(ModelClassInitializer):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name
|
custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name
|
||||||
try:
|
|
||||||
custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name,
|
custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name,
|
||||||
base_manager_info=manager_info)
|
base_manager_info=manager_info)
|
||||||
except helpers.IncompleteDefnException:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.add_new_node_to_model_class(manager_name, custom_manager_type)
|
self.add_new_node_to_model_class(manager_name, custom_manager_type,
|
||||||
|
force_replace_existing=True)
|
||||||
|
|
||||||
|
|
||||||
class AddDefaultManagerAttribute(ModelClassInitializer):
|
class AddDefaultManagerAttribute(ModelClassInitializer):
|
||||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||||
# add _default_manager
|
if self.model_class_has_attribute_defined('_default_manager', traverse_mro=False):
|
||||||
if '_default_manager' not in self.model_classdef.info.names:
|
return
|
||||||
|
if model_cls._meta.default_manager is None:
|
||||||
|
return
|
||||||
|
if getattr(model_cls._meta.default_manager, '_built_with_as_manager', False):
|
||||||
|
self.add_new_node_to_model_class('_default_manager',
|
||||||
|
AnyType(TypeOfAny.explicit))
|
||||||
|
return
|
||||||
|
|
||||||
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
|
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
|
||||||
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname)
|
resolved_default_manager_fullname = self.resolve_manager_fullname(default_manager_fullname)
|
||||||
|
|
||||||
|
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(resolved_default_manager_fullname)
|
||||||
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
|
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
|
||||||
self.add_new_node_to_model_class('_default_manager', default_manager)
|
self.add_new_node_to_model_class('_default_manager', default_manager)
|
||||||
|
|
||||||
@@ -249,33 +333,37 @@ class AddRelatedManagers(ModelClassInitializer):
|
|||||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||||
# add related managers
|
# add related managers
|
||||||
for relation in self.django_context.get_model_relations(model_cls):
|
for relation in self.django_context.get_model_relations(model_cls):
|
||||||
attname = relation.get_accessor_name()
|
related_manager_attr_name = relation.get_accessor_name()
|
||||||
if attname is None:
|
if related_manager_attr_name is None:
|
||||||
# no reverse accessor
|
# no reverse accessor
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self.model_class_has_attribute_defined(related_manager_attr_name, traverse_mro=False):
|
||||||
|
continue
|
||||||
|
|
||||||
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
||||||
if related_model_cls is None:
|
if related_model_cls is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
|
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
|
||||||
except helpers.IncompleteDefnException as exc:
|
except sem_helpers.IncompleteDefnException as exc:
|
||||||
if not self.api.final_iteration:
|
if not self.api.final_iteration:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(relation, OneToOneRel):
|
if isinstance(relation, OneToOneRel):
|
||||||
self.add_new_node_to_model_class(attname, Instance(related_model_info, []))
|
self.add_new_node_to_model_class(related_manager_attr_name, Instance(related_model_info, []))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(relation, (ManyToOneRel, ManyToManyRel)):
|
if isinstance(relation, (ManyToOneRel, ManyToManyRel)):
|
||||||
try:
|
try:
|
||||||
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS) # noqa: E501
|
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(
|
||||||
|
fullnames.RELATED_MANAGER_CLASS) # noqa: E501
|
||||||
if 'objects' not in related_model_info.names:
|
if 'objects' not in related_model_info.names:
|
||||||
raise helpers.IncompleteDefnException()
|
raise sem_helpers.IncompleteDefnException()
|
||||||
except helpers.IncompleteDefnException as exc:
|
except sem_helpers.IncompleteDefnException as exc:
|
||||||
if not self.api.final_iteration:
|
if not self.api.final_iteration:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
@@ -288,14 +376,20 @@ class AddRelatedManagers(ModelClassInitializer):
|
|||||||
if (default_manager_type is None
|
if (default_manager_type is None
|
||||||
or not isinstance(default_manager_type, Instance)
|
or not isinstance(default_manager_type, Instance)
|
||||||
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
|
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
|
||||||
self.add_new_node_to_model_class(attname, parametrized_related_manager_type)
|
self.add_new_node_to_model_class(related_manager_attr_name, parametrized_related_manager_type)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = related_model_cls.__name__ + '_' + 'RelatedManager'
|
name = related_model_cls.__name__ + '_' + 'RelatedManager'
|
||||||
bases = [parametrized_related_manager_type, default_manager_type]
|
bases = [parametrized_related_manager_type, default_manager_type]
|
||||||
new_related_manager_info = self.add_new_class_for_current_module(name, bases)
|
new_related_manager_info = self.add_new_class_for_current_module(name, bases,
|
||||||
|
force_replace_existing=True)
|
||||||
|
if new_related_manager_info is None:
|
||||||
|
# wasn't added for some reason, defer
|
||||||
|
if not self.api.final_iteration:
|
||||||
|
self.api.defer()
|
||||||
|
continue
|
||||||
|
|
||||||
self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, []))
|
self.add_new_node_to_model_class(related_manager_attr_name, Instance(new_related_manager_info, []))
|
||||||
|
|
||||||
|
|
||||||
class AddExtraFieldMethods(ModelClassInitializer):
|
class AddExtraFieldMethods(ModelClassInitializer):
|
||||||
@@ -355,6 +449,8 @@ def process_model_class(ctx: ClassDefContext,
|
|||||||
for initializer_cls in initializers:
|
for initializer_cls in initializers:
|
||||||
try:
|
try:
|
||||||
initializer_cls(ctx, django_context).run()
|
initializer_cls(ctx, django_context).run()
|
||||||
except helpers.IncompleteDefnException:
|
except sem_helpers.IncompleteDefnException as exc:
|
||||||
if not ctx.api.final_iteration:
|
if not ctx.api.final_iteration:
|
||||||
ctx.api.defer()
|
ctx.api.defer()
|
||||||
|
continue
|
||||||
|
raise exc
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from mypy.types import Type as MypyType
|
|||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||||
@@ -35,7 +35,7 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
|
|||||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
helpers.check_types_compatible(ctx,
|
chk_helpers.check_types_compatible(ctx,
|
||||||
expected_type=lookup_type,
|
expected_type=lookup_type,
|
||||||
actual_type=provided_type,
|
actual_type=provided_type,
|
||||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from mypy.types import TypeOfAny
|
|||||||
from mypy_django_plugin.django.context import (
|
from mypy_django_plugin.django.context import (
|
||||||
DjangoContext, LookupsAreUnsupported,
|
DjangoContext, LookupsAreUnsupported,
|
||||||
)
|
)
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
|
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
|
||||||
@@ -30,7 +30,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
|||||||
default_return_type = ctx.default_return_type
|
default_return_type = ctx.default_return_type
|
||||||
assert isinstance(default_return_type, Instance)
|
assert isinstance(default_return_type, Instance)
|
||||||
|
|
||||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
|
||||||
if (outer_model_info is None
|
if (outer_model_info is None
|
||||||
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
||||||
return default_return_type
|
return default_return_type
|
||||||
@@ -55,7 +55,7 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
|
|||||||
return AnyType(TypeOfAny.from_error)
|
return AnyType(TypeOfAny.from_error)
|
||||||
lookup_field = django_context.get_primary_key_field(related_model_cls)
|
lookup_field = django_context.get_primary_key_field(related_model_cls)
|
||||||
|
|
||||||
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
field_get_type = django_context.get_field_get_type(chk_helpers.get_typechecker_api(ctx),
|
||||||
lookup_field, method=method)
|
lookup_field, method=method)
|
||||||
return field_get_type
|
return field_get_type
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
|||||||
if field_lookups is None:
|
if field_lookups is None:
|
||||||
return AnyType(TypeOfAny.from_error)
|
return AnyType(TypeOfAny.from_error)
|
||||||
|
|
||||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
typechecker_api = chk_helpers.get_typechecker_api(ctx)
|
||||||
if len(field_lookups) == 0:
|
if len(field_lookups) == 0:
|
||||||
if flat:
|
if flat:
|
||||||
primary_key_field = django_context.get_primary_key_field(model_cls)
|
primary_key_field = django_context.get_primary_key_field(model_cls)
|
||||||
@@ -80,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
|||||||
column_type = django_context.get_field_get_type(typechecker_api, field,
|
column_type = django_context.get_field_get_type(typechecker_api, field,
|
||||||
method='values_list')
|
method='values_list')
|
||||||
column_types[field.attname] = column_type
|
column_types[field.attname] = column_type
|
||||||
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
return chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||||
else:
|
else:
|
||||||
# flat=False, named=False, all fields
|
# flat=False, named=False, all fields
|
||||||
field_lookups = []
|
field_lookups = []
|
||||||
@@ -103,9 +103,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
|||||||
assert len(column_types) == 1
|
assert len(column_types) == 1
|
||||||
row_type = next(iter(column_types.values()))
|
row_type = next(iter(column_types.values()))
|
||||||
elif named:
|
elif named:
|
||||||
row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
row_type = chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||||
else:
|
else:
|
||||||
row_type = helpers.make_tuple(typechecker_api, list(column_types.values()))
|
row_type = chk_helpers.make_tuple(typechecker_api, list(column_types.values()))
|
||||||
|
|
||||||
return row_type
|
return row_type
|
||||||
|
|
||||||
@@ -123,13 +123,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
|
|||||||
if model_cls is None:
|
if model_cls is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
|
flat_expr = chk_helpers.get_call_argument_by_name(ctx, 'flat')
|
||||||
if flat_expr is not None and isinstance(flat_expr, NameExpr):
|
if flat_expr is not None and isinstance(flat_expr, NameExpr):
|
||||||
flat = helpers.parse_bool(flat_expr)
|
flat = helpers.parse_bool(flat_expr)
|
||||||
else:
|
else:
|
||||||
flat = False
|
flat = False
|
||||||
|
|
||||||
named_expr = helpers.get_call_argument_by_name(ctx, 'named')
|
named_expr = chk_helpers.get_call_argument_by_name(ctx, 'named')
|
||||||
if named_expr is not None and isinstance(named_expr, NameExpr):
|
if named_expr is not None and isinstance(named_expr, NameExpr):
|
||||||
named = helpers.parse_bool(named_expr)
|
named = helpers.parse_bool(named_expr)
|
||||||
else:
|
else:
|
||||||
@@ -188,5 +188,5 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
|
|||||||
|
|
||||||
column_types[field_lookup] = field_lookup_type
|
column_types[field_lookup] = field_lookup_type
|
||||||
|
|
||||||
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
|
row_type = chk_helpers.make_oneoff_typeddict(ctx.api, column_types, set(column_types.keys()))
|
||||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
|
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ from mypy.types import Instance
|
|||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import helpers
|
from mypy_django_plugin.lib import helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType:
|
def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType:
|
||||||
auth_user_model = django_context.settings.AUTH_USER_MODEL
|
auth_user_model = django_context.settings.AUTH_USER_MODEL
|
||||||
model_cls = django_context.apps_registry.get_model(auth_user_model)
|
model_cls = django_context.apps_registry.get_model(auth_user_model)
|
||||||
model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
|
model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls)
|
||||||
if model_info is None:
|
if model_info is None:
|
||||||
return ctx.default_attr_type
|
return ctx.default_attr_type
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from mypy.types import Type as MypyType
|
|||||||
from mypy.types import TypeOfAny, TypeType
|
from mypy.types import TypeOfAny, TypeType
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import helpers
|
from mypy_django_plugin.lib import helpers, chk_helpers
|
||||||
|
|
||||||
|
|
||||||
def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
||||||
@@ -13,7 +13,7 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) ->
|
|||||||
model_cls = django_context.apps_registry.get_model(auth_user_model)
|
model_cls = django_context.apps_registry.get_model(auth_user_model)
|
||||||
model_cls_fullname = helpers.get_class_fullname(model_cls)
|
model_cls_fullname = helpers.get_class_fullname(model_cls)
|
||||||
|
|
||||||
model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx),
|
model_info = helpers.lookup_fully_qualified_typeinfo(chk_helpers.get_typechecker_api(ctx),
|
||||||
model_cls_fullname)
|
model_cls_fullname)
|
||||||
if model_info is None:
|
if model_info is None:
|
||||||
return AnyType(TypeOfAny.unannotated)
|
return AnyType(TypeOfAny.unannotated)
|
||||||
@@ -28,7 +28,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django
|
|||||||
ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context)
|
ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context)
|
||||||
return ctx.default_attr_type
|
return ctx.default_attr_type
|
||||||
|
|
||||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
typechecker_api = chk_helpers.get_typechecker_api(ctx)
|
||||||
|
|
||||||
# first look for the setting in the project settings file, then global settings
|
# first look for the setting in the project settings file, then global settings
|
||||||
settings_module = typechecker_api.modules.get(django_context.django_settings_module)
|
settings_module = typechecker_api.modules.get(django_context.django_settings_module)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from myapp.models import MyModel
|
from myapp.models import MyModel
|
||||||
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
|
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
|
||||||
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
|
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
|
||||||
|
reveal_type(MyModel().objects.queryset_method) # N: Revealed type is 'def () -> builtins.str'
|
||||||
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str'
|
reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str'
|
||||||
installed_apps:
|
installed_apps:
|
||||||
- myapp
|
- myapp
|
||||||
@@ -179,3 +180,56 @@
|
|||||||
class BaseQuerySet(models.QuerySet):
|
class BaseQuerySet(models.QuerySet):
|
||||||
def base_queryset_method(self, param: Union[int, str]) -> NoReturn:
|
def base_queryset_method(self, param: Union[int, str]) -> NoReturn:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
- case: from_queryset_with_inherited_manager_and_fk_to_auth_contrib
|
||||||
|
disable_cache: true
|
||||||
|
main: |
|
||||||
|
from myapp.base_queryset import BaseQuerySet
|
||||||
|
reveal_type(BaseQuerySet().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
|
||||||
|
|
||||||
|
from django.contrib.auth.models import Permission
|
||||||
|
reveal_type(Permission().another_models) # N: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.AnotherModelInProjectWithContribAuthM2M]'
|
||||||
|
|
||||||
|
from myapp.managers import NewManager
|
||||||
|
reveal_type(NewManager()) # N: Revealed type is 'myapp.managers.NewManager'
|
||||||
|
reveal_type(NewManager().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
|
||||||
|
|
||||||
|
from myapp.models import MyModel
|
||||||
|
reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]'
|
||||||
|
reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*'
|
||||||
|
reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]'
|
||||||
|
installed_apps:
|
||||||
|
- myapp
|
||||||
|
- django.contrib.auth
|
||||||
|
files:
|
||||||
|
- path: myapp/__init__.py
|
||||||
|
- path: myapp/models.py
|
||||||
|
content: |
|
||||||
|
from django.db import models
|
||||||
|
from myapp.managers import NewManager
|
||||||
|
from django.contrib.auth.models import Permission
|
||||||
|
|
||||||
|
class MyModel(models.Model):
|
||||||
|
objects = NewManager()
|
||||||
|
|
||||||
|
class AnotherModelInProjectWithContribAuthM2M(models.Model):
|
||||||
|
permissions = models.ForeignKey(
|
||||||
|
Permission,
|
||||||
|
on_delete=models.PROTECT,
|
||||||
|
related_name='another_models'
|
||||||
|
)
|
||||||
|
- path: myapp/managers.py
|
||||||
|
content: |
|
||||||
|
from django.db import models
|
||||||
|
from myapp.base_queryset import BaseQuerySet
|
||||||
|
class ModelQuerySet(BaseQuerySet):
|
||||||
|
pass
|
||||||
|
NewManager = models.Manager.from_queryset(ModelQuerySet)
|
||||||
|
- path: myapp/base_queryset.py
|
||||||
|
content: |
|
||||||
|
from typing import Union, Dict
|
||||||
|
from django.db import models
|
||||||
|
class BaseQuerySet(models.QuerySet):
|
||||||
|
def base_queryset_method(self, param: Dict[str, Union[int, str]]) -> Union[int, str]:
|
||||||
|
return param["hello"]
|
||||||
@@ -307,15 +307,15 @@
|
|||||||
- case: custom_manager_returns_proper_model_types
|
- case: custom_manager_returns_proper_model_types
|
||||||
main: |
|
main: |
|
||||||
from myapp.models import User
|
from myapp.models import User
|
||||||
reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]'
|
reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
|
||||||
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]'
|
reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]'
|
||||||
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
|
reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*'
|
||||||
reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int'
|
reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int'
|
||||||
reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'
|
reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'
|
||||||
|
|
||||||
from myapp.models import ChildUser
|
from myapp.models import ChildUser
|
||||||
reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]'
|
reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]'
|
||||||
reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]'
|
reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]'
|
||||||
reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*'
|
reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*'
|
||||||
reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int'
|
reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int'
|
||||||
reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'
|
reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any'
|
||||||
|
|||||||
Reference in New Issue
Block a user