from typing import Dict, List, Optional, Set, Union

from mypy.checker import TypeChecker
from mypy.nodes import Expression, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, Type as MypyType, TypeOfAny, UnionType


class IncompleteDefnException(Exception):
    """Custom exception type; presumably signals a not-yet-complete definition.

    NOTE(review): nothing in this part of the file raises or catches it.
    """


def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
    """Resolve a dotted ``module.Name`` string to the symbol node it refers to.

    The name is split on its last dot: the left part is looked up in
    ``all_modules`` and the right part in that module's symbol table.

    Returns ``None`` when the name has no dot, the module is not loaded, or
    the module does not define the symbol.
    """
    if '.' not in name:
        return None
    module, cls_name = name.rsplit('.', 1)

    module_file = all_modules.get(module)
    if module_file is None:
        return None

    sym = module_file.names.get(cls_name)
    if sym is None:
        return None
    return sym.node


def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` registered under ``fullname``.

    Returns ``None`` when the name cannot be resolved or resolves to
    something that is not a ``TypeInfo``.
    """
    node = lookup_fully_qualified_generic(fullname, api.modules)
    if not isinstance(node, TypeInfo):
        return None
    return node


def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` matching the runtime class ``klass``.

    The lookup goes through the class' fully qualified name, so the result is
    ``None`` whenever ``lookup_fully_qualified_typeinfo`` cannot resolve it.
    """
    return lookup_fully_qualified_typeinfo(api, get_class_fullname(klass))


def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
    """Build a new ``Instance`` of the same type with ``new_args`` as type arguments.

    Line and column are copied from ``instance``; the original is not modified.
    """
    return Instance(instance.type, args=new_args,
                    line=instance.line, column=instance.column)


def get_class_fullname(klass: type) -> str:
    """Return ``<module>.<qualname>`` for a runtime class."""
    return klass.__module__ + '.' + klass.__qualname__


def _callee_arg_index(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[int]:
    """Return the position of ``name`` among the callee's argument names, or ``None``."""
    try:
        return ctx.callee_arg_names.index(name)
    except ValueError:
        return None


def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
    """Return the expression passed for the callee argument called ``name``.

    This helper should only be used with non-star arguments.

    Returns ``None`` when the callee has no such argument, or when the number
    of expressions bound to it is not exactly one.
    """
    idx = _callee_arg_index(ctx, name)
    if idx is None:
        return None
    args = ctx.args[idx]
    if len(args) != 1:
        # Either an error or no value passed.
        return None
    return args[0]


def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
    """Return the type of the value passed for the callee argument called ``name``.

    This helper should only be used with non-star arguments.

    Returns ``None`` when the callee has no such argument, or when the number
    of types bound to it is not exactly one.
    """
    idx = _callee_arg_index(ctx, name)
    if idx is None:
        return None
    arg_types = ctx.arg_types[idx]
    if len(arg_types) != 1:
        # Either an error or no value passed.
        return None
    return arg_types[0]


def make_optional(typ: MypyType) -> MypyType:
    """Return the union of ``typ`` and ``None``."""
    return UnionType.make_union([typ, NoneTyp()])


def parse_bool(expr: Expression) -> Optional[bool]:
    """Interpret ``expr`` as a literal ``True``/``False`` name.

    Returns the corresponding bool, or ``None`` for any other expression.
    """
    if isinstance(expr, NameExpr):
        if expr.fullname == 'builtins.True':
            return True
        if expr.fullname == 'builtins.False':
            return False
    return None


def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
    """Return ``True`` if ``info`` has at least one of the given base fullnames."""
    return any(info.has_base(base_fullname) for base_fullname in bases)


def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
    """Return the declared type of the ``private_field_name`` attribute of ``type_info``.

    When ``is_nullable`` is true the declared type is wrapped with
    ``make_optional``. An unannotated ``Any`` is returned when the attribute
    is missing, is not a variable, or has no type recorded on it.
    """
    fallback = AnyType(TypeOfAny.unannotated)

    # TypeInfo.get() yields None for an unknown name; guard before touching .node.
    sym = type_info.get(private_field_name)
    if sym is None:
        return fallback

    node = sym.node
    if not isinstance(node, Var):
        return fallback

    descriptor_type = node.type
    if descriptor_type is None:
        # A Var may have no type attached; never leak None as a type.
        return fallback

    if is_nullable:
        descriptor_type = make_optional(descriptor_type)
    return descriptor_type


def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` of the ``Meta`` class found in ``info.names``.

    Returns ``None`` when there is no ``Meta`` entry or it is not a class.
    """
    metaclass_sym = info.names.get('Meta')
    if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
        return metaclass_sym.node
    return None


def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
    """Replace ``Any`` occurrences in ``typ`` with ``referred_to_type``.

    * Union: every item is converted recursively and the union is rebuilt.
    * Instance: type arguments that are directly ``Any`` are replaced
      (arguments are not descended into).
    * ``Any`` itself: replaced by ``referred_to_type``.
    * Anything else is returned unchanged.
    """
    if isinstance(typ, UnionType):
        converted_items = [convert_any_to_type(item, referred_to_type)
                           for item in typ.items]
        return UnionType.make_union(converted_items,
                                    line=typ.line, column=typ.column)

    if isinstance(typ, Instance):
        args = [referred_to_type if isinstance(default_arg, AnyType) else default_arg
                for default_arg in typ.args]
        return reparametrize_instance(typ, args)

    if isinstance(typ, AnyType):
        return referred_to_type

    return typ