Files
django-stubs/mypy_django_plugin/lib/helpers.py

259 lines
9.3 KiB
Python

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
SymbolTable, SymbolTableNode, TypeInfo, Var)
from mypy.plugin import (
AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext,
)
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
from mypy_django_plugin.lib import fullnames
if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
    """Return the ``'django'`` entry of ``model_info.metadata``.

    The entry is created as an empty dict (and stored on the metadata) when it
    does not exist yet, so callers can mutate the returned mapping in place.
    """
    metadata = model_info.metadata
    return metadata.setdefault('django', {})
class IncompleteDefnException(Exception):
    """Exception type declared by the plugin helpers; it adds nothing to ``Exception``."""
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
    """Find the symbol table node for a dotted ``module.name`` path.

    The text after the last dot is looked up in the ``names`` table of the
    module named by the text before it. Returns ``None`` when ``fullname`` has
    no dot, when the module is not in ``all_modules``, or when the module has
    no such name.
    """
    module_name, dot, symbol_name = fullname.rpartition('.')
    if not dot:
        # No dot at all: there is no module part to look in.
        return None

    module_file = all_modules.get(module_name)
    if module_file is None:
        return None
    return module_file.names.get(symbol_name)
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
    """Return the ``node`` of the symbol found for ``name``, or ``None`` if no symbol is found."""
    found = lookup_fully_qualified_sym(name, all_modules)
    return found.node if found is not None else None
def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]:
    """Look ``fullname`` up in ``api.modules`` and return it only if it is a ``TypeInfo``.

    Any other kind of node (or a failed lookup) yields ``None``.
    """
    candidate = lookup_fully_qualified_generic(fullname, api.modules)
    return candidate if isinstance(candidate, TypeInfo) else None
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` registered under the full dotted name of the runtime class ``klass``.

    Returns ``None`` when the lookup by that name does not produce a ``TypeInfo``.
    """
    return lookup_fully_qualified_typeinfo(api, get_class_fullname(klass))
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
    """Build a new ``Instance`` of the same type with ``new_args`` as its type arguments.

    The line and column of the original ``instance`` are carried over.
    """
    return Instance(
        instance.type,
        args=new_args,
        line=instance.line,
        column=instance.column,
    )
def get_class_fullname(klass: type) -> str:
    """Return ``<module>.<qualified name>`` for the runtime class ``klass``."""
    return '.'.join((klass.__module__, klass.__qualname__))
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
    """
    Return the expression passed for the callee argument called ``name``.

    Only meant for non-star arguments. ``None`` is returned when the callee
    has no such argument, or when the number of expressions bound to it is
    not exactly one.
    """
    try:
        position = ctx.callee_arg_names.index(name)
    except ValueError:
        # The callee does not declare an argument with this name.
        return None

    bound_exprs = ctx.args[position]
    # Anything other than exactly one expression means an error or no value passed.
    return bound_exprs[0] if len(bound_exprs) == 1 else None
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
    """
    Return the type of the value passed for the callee argument called ``name``.

    Only meant for non-star arguments. ``None`` is returned when the callee
    has no such argument, or when the number of types bound to it is not
    exactly one.
    """
    try:
        position = ctx.callee_arg_names.index(name)
    except ValueError:
        # The callee does not declare an argument with this name.
        return None

    bound_types = ctx.arg_types[position]
    # Anything other than exactly one type means an error or no value passed.
    return bound_types[0] if len(bound_types) == 1 else None
def make_optional(typ: MypyType) -> MypyType:
    """Return the union of ``typ`` with a fresh ``NoneTyp``."""
    union_members = [typ, NoneTyp()]
    return UnionType.make_union(union_members)
def parse_bool(expr: Expression) -> Optional[bool]:
    """Map a ``NameExpr`` referring to the builtin ``True``/``False`` to that bool.

    Any other expression, or a name with a different fullname, gives ``None``.
    """
    if not isinstance(expr, NameExpr):
        return None
    literal_by_fullname = {'builtins.True': True, 'builtins.False': False}
    return literal_by_fullname.get(expr.fullname)
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
    """Tell whether ``info.has_base`` succeeds for at least one fullname in ``bases``."""
    return any(info.has_base(candidate) for candidate in bases)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
    """Return the declared type of ``private_field_name`` on ``type_info``.

    Used for private Field attributes. An unannotated ``Any`` is returned when
    the name is missing, is not a ``Var``, or has no declared type. With
    ``is_nullable`` set, the declared type is wrapped via ``make_optional``.
    """
    sym = type_info.get(private_field_name)
    node = sym.node if sym is not None else None

    if not isinstance(node, Var) or node.type is None:
        return AnyType(TypeOfAny.unannotated)

    declared_type = node.type
    return make_optional(declared_type) if is_nullable else declared_type
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` stored under the name ``Meta`` in ``info.names``.

    Only the names of ``info`` itself are consulted. ``None`` is returned when
    there is no ``Meta`` entry or its node is not a ``TypeInfo``.
    """
    meta_sym = info.names.get('Meta')
    if meta_sym is None:
        return None
    meta_node = meta_sym.node
    return meta_node if isinstance(meta_node, TypeInfo) else None
def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
                             fields: 'OrderedDict[str, MypyType]') -> TypeInfo:
    """Create a new class ``TypeInfo`` and register it in ``module.names``.

    The class name is derived from ``name`` through ``checker.gen_unique_name``
    against the module's existing names. Every entry of ``fields`` becomes a
    ``Var`` member flagged as a property. The symbol added to the module is
    marked ``plugin_generated``. Returns the new ``TypeInfo``.
    """
    unique_name = checker.gen_unique_name(name, module.names)

    # Class definition with an empty body, named inside the target module.
    class_def = ClassDef(unique_name, Block([]))
    class_def.fullname = module.fullname() + '.' + unique_name

    # TypeInfo backing the definition.
    info = TypeInfo(SymbolTable(), class_def, module.fullname())
    info.bases = bases
    calculate_mro(info)
    # NOTE(review): the return value is discarded here, exactly as before;
    # confirm whether it was meant to be stored on the TypeInfo.
    info.calculate_metaclass_type()

    # One Var per requested field, attached to the new TypeInfo.
    for field_name, field_type in fields.items():
        field_var = Var(field_name, field_type)
        field_var.info = info
        field_var.is_initialized_in_class = False
        field_var.is_property = True
        field_var._fullname = info.fullname() + '.' + field_var.name()
        info.names[field_var.name()] = SymbolTableNode(MDEF, field_var)

    class_def.info = info
    module.names[unique_name] = SymbolTableNode(GDEF, info, plugin_generated=True)
    return info
def get_current_module(api: TypeChecker) -> MypyFile:
    """Return the last ``MypyFile`` found on ``api.scope.stack``.

    The stack is searched from its end; an assertion fails if it holds no
    ``MypyFile`` at all.
    """
    enclosing_module = next(
        (entry for entry in reversed(api.scope.stack) if isinstance(entry, MypyFile)),
        None,
    )
    assert enclosing_module is not None
    return enclosing_module
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
    """Create a ``TupleType`` backed by a newly generated class based on ``typing.NamedTuple``.

    The class is added to the current module with ``fields`` as its members;
    the tuple items are the field types in the order given by ``fields``.
    """
    namedtuple_base = api.named_generic_type('typing.NamedTuple', [])
    generated_info = add_new_class_for_module(
        get_current_module(api),
        name,
        bases=[namedtuple_base],
        fields=fields,
    )
    item_types = list(fields.values())
    return TupleType(item_types, fallback=Instance(generated_info, []))
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
    """Build a ``TupleType`` from ``fields`` whose fallback is ``builtins.tuple`` of ``Any``."""
    any_item = AnyType(TypeOfAny.special_form)
    tuple_fallback = api.named_generic_type('builtins.tuple', [any_item])
    return TupleType(fields, fallback=tuple_fallback)
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
    """Substitute ``referred_to_type`` for ``Any`` inside ``typ``.

    - A bare ``AnyType`` is replaced outright.
    - A ``UnionType`` is rebuilt from its items, each converted recursively.
    - An ``Instance`` has each direct ``AnyType`` argument replaced; other
      arguments are kept untouched (no recursion into them).
    - Every other type is returned unchanged.
    """
    if isinstance(typ, AnyType):
        return referred_to_type

    if isinstance(typ, UnionType):
        members = [convert_any_to_type(member, referred_to_type) for member in typ.items]
        return UnionType.make_union(members, line=typ.line, column=typ.column)

    if isinstance(typ, Instance):
        new_args = [
            referred_to_type if isinstance(arg, AnyType) else arg
            for arg in typ.args
        ]
        return reparametrize_instance(typ, new_args)

    return typ
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
                   required_keys: Set[str]) -> TypedDictType:
    """Build a ``TypedDictType`` for ``fields`` with ``mypy_extensions._TypedDict`` as fallback."""
    fallback = api.named_generic_type('mypy_extensions._TypedDict', [])
    return TypedDictType(fields, required_keys=required_keys, fallback=fallback)
def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionContext, MethodContext],
                                   django_context: 'DjangoContext') -> Optional[str]:
    """Resolve ``attr_expr`` to a value, reporting a failure when that is not possible.

    A string literal yields its value. A member access on a name whose fullname
    is ``django.conf.settings`` yields the matching attribute of
    ``django_context.settings`` when that attribute exists. In every other case
    an error is reported through ``ctx.api.fail`` and ``None`` is returned.
    """
    if isinstance(attr_expr, StrExpr):
        return attr_expr.value

    # Settings attributes are the only non-literal form that can be resolved here.
    if isinstance(attr_expr, MemberExpr):
        setting_name = attr_expr.name
        owner = attr_expr.expr
        is_settings_access = isinstance(owner, NameExpr) and owner.fullname == 'django.conf.settings'
        if is_settings_access and hasattr(django_context.settings, setting_name):
            return getattr(django_context.settings, setting_name)

    ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context)
    return None
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
    """Return ``ctx.api`` typed as a ``TypeChecker``.

    Raises:
        ValueError: if ``ctx.api`` is not a ``TypeChecker`` instance.
    """
    api = ctx.api
    if isinstance(api, TypeChecker):
        return cast(TypeChecker, api)
    raise ValueError('Not a TypeChecker')
def get_all_model_mixins(api: TypeChecker) -> Set[str]:
    """Return the keys stored under ``'model_mixins'`` in the base model's django metadata.

    The base model is looked up by ``fullnames.MODEL_CLASS_FULLNAME``. An empty
    set is returned when that class cannot be found, or when no
    ``'model_mixins'`` entry has been recorded on it.
    """
    basemodel_info = lookup_fully_qualified_typeinfo(api, fullnames.MODEL_CLASS_FULLNAME)
    if basemodel_info is None:
        return set()

    # The fallback must be an empty dict *instance*: passing the ``dict`` class
    # makes the missing-key path call ``dict.keys()`` unbound, which raises TypeError.
    model_mixins = get_django_metadata(basemodel_info).get('model_mixins', {})
    return set(model_mixins)