diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index f743e07..e4ca1ac 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,5 +1,5 @@ from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union, + TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, ) from django.db.models.fields import Field @@ -8,7 +8,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel from mypy.checker import TypeChecker from mypy.mro import calculate_mro from mypy.nodes import ( - Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTableNode, + Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode, TypeInfo, Var, ) from mypy.semanal import SemanticAnalyzer @@ -28,23 +28,32 @@ def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: return model_info.metadata.setdefault('django', {}) -def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]: +def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]: if '.' not in fullname: return None - module_file = None + module_name = None parts = fullname.split('.') for i in range(len(parts), 0, -1): possible_module_name = '.'.join(parts[:i]) if possible_module_name in all_modules: - module_file = all_modules[possible_module_name] + module_name = possible_module_name break - if module_file is None: + if module_name is None: return None - cls_name = fullname.replace(module_file.fullname, '').lstrip('.') - sym_table = module_file.names + cls_name = fullname.replace(module_name, '').lstrip('.') + return module_name, cls_name + + +def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]: + split = split_symbol_name(fullname, api.modules) + if split is None: + return None + module_name, cls_name = split + + sym_table = api.modules[module_name].names # type: Dict[str, SymbolTableNode] if '.' in cls_name: parent_cls_name, _, cls_name = cls_name.rpartition('.') # nested class @@ -55,23 +64,14 @@ def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) return None sym_table = sym.node.names - return sym_table.get(cls_name) - - -def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: - sym = lookup_fully_qualified_sym(name, all_modules) - if sym is None: + sym = sym_table.get(cls_name) + if (sym is None + or sym.node is None + or not isinstance(sym.node, TypeInfo)): return None return sym.node -def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]: - node = lookup_fully_qualified_generic(fullname, api.modules) - if not isinstance(node, TypeInfo): - return None - return node - - def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]: fullname = get_class_fullname(klass) field_info = lookup_fully_qualified_typeinfo(api, fullname) diff --git a/mypy_django_plugin/lib/sem_helpers.py b/mypy_django_plugin/lib/sem_helpers.py index 33db42e..139d604 100644 --- a/mypy_django_plugin/lib/sem_helpers.py +++ b/mypy_django_plugin/lib/sem_helpers.py @@ -4,9 +4,9 @@ from mypy.nodes import Argument, FuncDef, TypeInfo, Var from mypy.plugin import ClassDefContext, DynamicClassDefContext from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzer -from mypy.types import AnyType, CallableType, Instance +from mypy.types import AnyType, CallableType, Instance, PlaceholderType from mypy.types import Type as MypyType -from mypy.types import TypeOfAny +from mypy.types import TypeOfAny, get_proper_type class IncompleteDefnException(Exception): @@ -54,8 +54,9 @@ def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> S for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:], method_type.arg_types[1:], method_node.arguments[1:]): - analyzed_arg_type = api.anal_type(arg_type) - if analyzed_arg_type is None: + analyzed_arg_type = api.anal_type(get_proper_type(arg_type), allow_placeholder=True) + assert analyzed_arg_type is not None + if isinstance(analyzed_arg_type, PlaceholderType): unbound = True var = Var(name=original_argument.variable.name, @@ -63,14 +64,15 @@ def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> S var.set_line(original_argument.variable) argument = Argument(variable=var, - type_annotation=arg_type, + type_annotation=analyzed_arg_type, initializer=original_argument.initializer, kind=original_argument.kind) argument.set_line(original_argument) arguments.append(argument) - analyzed_ret_type = api.anal_type(method_type.ret_type) - if analyzed_ret_type is None: + analyzed_ret_type = api.anal_type(get_proper_type(method_type.ret_type), allow_placeholder=True) + assert analyzed_ret_type is not None + if isinstance(analyzed_ret_type, PlaceholderType): unbound = True return SignatureTuple(arguments, analyzed_ret_type, unbound) diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 747e273..78d2230 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -60,15 +60,11 @@ def resolve_passed_queryset_info_or_exception(ctx: DynamicClassDefContext) -> Ty def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo: api = sem_helpers.get_semanal_api(ctx) - - sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) - if (sym is None - or sym.node is None - or isinstance(sym.node, PlaceholderNode)): + info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME) + if info is None: raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME) - assert isinstance(sym.node, TypeInfo) - return sym.node + return info def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo: