Files
django-stubs/mypy_django_plugin/lib/helpers.py
Sigurd Ljødal 830d74b493 Add support for inline from_queryset in model classes (#1045)
* Add support for inline from_queryset in model classes

This adds support for calling <Manager>.from_queryset(<QuerySet>)()
inline in models, for example like this:

    class MyModel(models.Model):
        objects = MyManager.from_queryset(MyQuerySet)()

This is done by inspecting the class body in the transform_class_hook

* Fix missing methods on copied manager

* Add test and other minor tweaks

* Always create manager at module level

When the manager is added at the class level, which happened when it was
created inline in the model body, it's not possible to retrieve the
manager again based on fullname. That lead to problems with inheritance
and the default manager.
2022-07-13 10:04:44 +03:00

449 lines
16 KiB
Python

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import (
GDEF,
MDEF,
Argument,
Block,
ClassDef,
Expression,
FuncDef,
MemberExpr,
MypyFile,
NameExpr,
StrExpr,
SymbolNode,
SymbolTable,
SymbolTableNode,
TypeInfo,
Var,
)
from mypy.plugin import (
AttributeContext,
CheckerPluginInterface,
ClassDefContext,
DynamicClassDefContext,
FunctionContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins.common import add_method_to_class
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
from mypy_django_plugin.lib import fullnames
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault("django", {})
class IncompleteDefnException(Exception):
pass
def is_toml(filename: str) -> bool:
return filename.lower().endswith(".toml")
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
if "." not in fullname:
return None
if "[" in fullname and "]" in fullname:
# We sometimes generate fake fullnames like a.b.C[x.y.Z] to provide a better representation to users
# Make sure that we handle lookups of those types of names correctly if the part inside [] contains "."
bracket_start = fullname.index("[")
fullname_without_bracket = fullname[:bracket_start]
module, cls_name = fullname_without_bracket.rsplit(".", 1)
cls_name += fullname[bracket_start:]
else:
module, cls_name = fullname.rsplit(".", 1)
module_file = all_modules.get(module)
if module_file is None:
return None
sym = module_file.names.get(cls_name)
if sym is None:
return None
return sym
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
sym = lookup_fully_qualified_sym(name, all_modules)
if sym is None:
return None
return sym.node
def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo):
return None
return node
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]:
fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname)
return field_info
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
return Instance(instance.type, args=new_args, line=instance.line, column=instance.column)
def get_class_fullname(klass: type) -> str:
return klass.__module__ + "." + klass.__qualname__
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""
Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]
def make_optional(typ: MypyType) -> MypyType:
return UnionType.make_union([typ, NoneTyp()])
def parse_bool(expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == "builtins.True":
return True
if expr.fullname == "builtins.False":
return False
return None
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
for base in info.bases:
yield base
yield from iter_bases(base.type)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
"""Return declared type of type_info's private_field_name (used for private Field attributes)"""
sym = type_info.get(private_field_name)
if sym is None:
return AnyType(TypeOfAny.explicit)
node = sym.node
if isinstance(node, Var):
descriptor_type = node.type
if descriptor_type is None:
return AnyType(TypeOfAny.explicit)
if is_nullable:
descriptor_type = make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.explicit)
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
if isinstance(field, (RelatedField, ForeignObjectRel)):
lookup_type_class = field.related_model
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
if rel_model_info is None:
return AnyType(TypeOfAny.from_error)
return make_optional(Instance(rel_model_info, []))
field_info = lookup_class_typeinfo(api, field.__class__)
if field_info is None:
return AnyType(TypeOfAny.explicit)
return get_private_descriptor_type(field_info, "_pyi_lookup_exact_type", is_nullable=field.null)
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
metaclass_sym = info.names.get("Meta")
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def is_annotated_model_fullname(model_cls_fullname: str) -> bool:
return model_cls_fullname.startswith(WITH_ANNOTATIONS_FULLNAME + "[")
def create_type_info(name: str, module: str, bases: List[Instance]) -> TypeInfo:
# make new class expression
classdef = ClassDef(name, Block([]))
classdef.fullname = module + "." + name
# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, module)
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
classdef.info = new_typeinfo
return new_typeinfo
def add_new_class_for_module(
module: MypyFile,
name: str,
bases: List[Instance],
fields: Optional[Dict[str, MypyType]] = None,
no_serialize: bool = False,
) -> TypeInfo:
new_class_unique_name = checker.gen_unique_name(name, module.names)
new_typeinfo = create_type_info(new_class_unique_name, module.fullname, bases)
# add fields
if fields:
for field_name, field_type in fields.items():
var = Var(field_name, type=field_type)
var.info = new_typeinfo
var._fullname = new_typeinfo.fullname + "." + field_name
new_typeinfo.names[field_name] = SymbolTableNode(
MDEF, var, plugin_generated=True, no_serialize=no_serialize
)
module.names[new_class_unique_name] = SymbolTableNode(
GDEF, new_typeinfo, plugin_generated=True, no_serialize=no_serialize
)
return new_typeinfo
def get_current_module(api: TypeChecker) -> MypyFile:
current_module = None
for item in reversed(api.scope.stack):
if isinstance(item, MypyFile):
current_module = item
break
assert current_module is not None
return current_module
def make_oneoff_named_tuple(
api: TypeChecker, name: str, fields: "OrderedDict[str, MypyType]", extra_bases: Optional[List[Instance]] = None
) -> TupleType:
current_module = get_current_module(api)
if extra_bases is None:
extra_bases = []
namedtuple_info = add_new_class_for_module(
current_module, name, bases=[api.named_generic_type("typing.NamedTuple", [])] + extra_bases, fields=fields
)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
def make_tuple(api: "TypeChecker", fields: List[MypyType]) -> TupleType:
# fallback for tuples is any builtins.tuple instance
fallback = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
return TupleType(fields, fallback=fallback)
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_union(converted_items, line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def make_typeddict(
api: CheckerPluginInterface, fields: "OrderedDict[str, MypyType]", required_keys: Set[str]
) -> TypedDictType:
object_type = api.named_generic_type("mypy_extensions._TypedDict", [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def resolve_string_attribute_value(attr_expr: Expression, django_context: "DjangoContext") -> Optional[str]:
if isinstance(attr_expr, StrExpr):
return attr_expr.value
# support extracting from settings, in general case it's unresolvable yet
if isinstance(attr_expr, MemberExpr):
member_name = attr_expr.name
if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == "django.conf.settings":
if hasattr(django_context.settings, member_name):
return getattr(django_context.settings, member_name)
return None
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
if not isinstance(ctx.api, SemanticAnalyzer):
raise ValueError("Not a SemanticAnalyzer")
return ctx.api
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker):
raise ValueError("Not a TypeChecker")
return ctx.api
def is_model_subclass_info(info: TypeInfo, django_context: "DjangoContext") -> bool:
return info.fullname in django_context.all_registered_model_class_fullnames or info.has_base(
fullnames.MODEL_CLASS_FULLNAME
)
def check_types_compatible(
ctx: Union[FunctionContext, MethodContext], *, expected_type: MypyType, actual_type: MypyType, error_message: str
) -> None:
api = get_typechecker_api(ctx)
api.check_subtype(actual_type, expected_type, ctx.context, error_message, "got", "expected")
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType, no_serialize: bool = False) -> None:
# type=: type of the variable itself
var = Var(name=name, type=sym_type)
# var.info: type of the object variable is bound to
var.info = info
var._fullname = info.fullname + "." + name
var.is_initialized_in_class = True
var.is_inferred = True
info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize)
def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
prepared_arguments = []
try:
arguments = method_node.arguments[1:]
except AttributeError:
arguments = []
for argument in arguments:
argument.type_annotation = AnyType(TypeOfAny.unannotated)
prepared_arguments.append(argument)
return_type = AnyType(TypeOfAny.unannotated)
return prepared_arguments, return_type
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
"""Analyze a type. If an unbound type, try to look it up in the given module name.
That should hopefully give a bound type."""
if isinstance(t, UnboundType) and module_name is not None:
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
if node is not None and node.type is not None:
return node.type
return api.anal_type(t)
def copy_method_to_another_class(
ctx: ClassDefContext,
self_type: Instance,
new_method_name: str,
method_node: FuncDef,
return_type: Optional[MypyType] = None,
original_module_name: Optional[str] = None,
) -> bool:
semanal_api = get_semanal_api(ctx)
if method_node.type is None:
arguments, return_type = build_unannotated_method_args(method_node)
add_method_to_class(
semanal_api, ctx.cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type
)
return True
method_type = method_node.type
if not isinstance(method_type, CallableType):
if not semanal_api.final_iteration:
semanal_api.defer()
return False
if return_type is None:
return_type = bind_or_analyze_type(method_type.ret_type, semanal_api, original_module_name)
if return_type is None:
return False
# We build the arguments from the method signature (`CallableType`), because if we were to
# use the arguments from the method node (`FuncDef.arguments`) we're not compatible with
# a method loaded from cache. As mypy doesn't serialize `FuncDef.arguments` when caching
arguments = []
# Note that the first argument is excluded, as that's `self`
for pos, (arg_type, arg_kind, arg_name) in enumerate(
zip(method_type.arg_types[1:], method_type.arg_kinds[1:], method_type.arg_names[1:]),
start=1,
):
bound_arg_type = bind_or_analyze_type(arg_type, semanal_api, original_module_name)
if bound_arg_type is None:
return False
if arg_name is None and hasattr(method_node, "arguments"):
arg_name = method_node.arguments[pos].variable.name
arguments.append(
Argument(
# Positional only arguments can have name as `None`, if we can't find a name, we just invent one..
variable=Var(name=arg_name if arg_name is not None else str(pos), type=arg_type),
type_annotation=bound_arg_type,
initializer=None,
kind=arg_kind,
pos_only=arg_name is None,
)
)
add_method_to_class(
semanal_api, ctx.cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type
)
return True
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
get_django_metadata(sym.node)["manager_bases"][fullname] = 1