fix tests

This commit is contained in:
Maxim Kurnikov
2020-01-05 08:18:43 +03:00
parent 0a92c89d41
commit e5b61dc499
3 changed files with 33 additions and 35 deletions

View File

@@ -1,5 +1,5 @@
from typing import ( from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union, TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union,
) )
from django.db.models.fields import Field from django.db.models.fields import Field
@@ -8,7 +8,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTableNode, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode,
TypeInfo, Var, TypeInfo, Var,
) )
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer
@@ -28,23 +28,32 @@ def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {}) return model_info.metadata.setdefault('django', {})
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]: def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]:
if '.' not in fullname: if '.' not in fullname:
return None return None
module_file = None module_name = None
parts = fullname.split('.') parts = fullname.split('.')
for i in range(len(parts), 0, -1): for i in range(len(parts), 0, -1):
possible_module_name = '.'.join(parts[:i]) possible_module_name = '.'.join(parts[:i])
if possible_module_name in all_modules: if possible_module_name in all_modules:
module_file = all_modules[possible_module_name] module_name = possible_module_name
break break
if module_file is None: if module_name is None:
return None return None
cls_name = fullname.replace(module_file.fullname, '').lstrip('.') cls_name = fullname.replace(module_name, '').lstrip('.')
sym_table = module_file.names return module_name, cls_name
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
split = split_symbol_name(fullname, api.modules)
if split is None:
return None
module_name, cls_name = split
sym_table = api.modules[module_name].names # type: Dict[str, SymbolTableNode]
if '.' in cls_name: if '.' in cls_name:
parent_cls_name, _, cls_name = cls_name.rpartition('.') parent_cls_name, _, cls_name = cls_name.rpartition('.')
# nested class # nested class
@@ -55,23 +64,14 @@ def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile])
return None return None
sym_table = sym.node.names sym_table = sym.node.names
return sym_table.get(cls_name) sym = sym_table.get(cls_name)
if (sym is None
or sym.node is None
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: or not isinstance(sym.node, TypeInfo)):
sym = lookup_fully_qualified_sym(name, all_modules)
if sym is None:
return None return None
return sym.node return sym.node
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo):
return None
return node
def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]: def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
fullname = get_class_fullname(klass) fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname) field_info = lookup_fully_qualified_typeinfo(api, fullname)

View File

@@ -4,9 +4,9 @@ from mypy.nodes import Argument, FuncDef, TypeInfo, Var
from mypy.plugin import ClassDefContext, DynamicClassDefContext from mypy.plugin import ClassDefContext, DynamicClassDefContext
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzer from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, CallableType, Instance from mypy.types import AnyType, CallableType, Instance, PlaceholderType
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy.types import TypeOfAny from mypy.types import TypeOfAny, get_proper_type
class IncompleteDefnException(Exception): class IncompleteDefnException(Exception):
@@ -54,8 +54,9 @@ def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> S
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:], for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
method_type.arg_types[1:], method_type.arg_types[1:],
method_node.arguments[1:]): method_node.arguments[1:]):
analyzed_arg_type = api.anal_type(arg_type) analyzed_arg_type = api.anal_type(get_proper_type(arg_type), allow_placeholder=True)
if analyzed_arg_type is None: assert analyzed_arg_type is not None
if isinstance(analyzed_arg_type, PlaceholderType):
unbound = True unbound = True
var = Var(name=original_argument.variable.name, var = Var(name=original_argument.variable.name,
@@ -63,14 +64,15 @@ def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> S
var.set_line(original_argument.variable) var.set_line(original_argument.variable)
argument = Argument(variable=var, argument = Argument(variable=var,
type_annotation=arg_type, type_annotation=analyzed_arg_type,
initializer=original_argument.initializer, initializer=original_argument.initializer,
kind=original_argument.kind) kind=original_argument.kind)
argument.set_line(original_argument) argument.set_line(original_argument)
arguments.append(argument) arguments.append(argument)
analyzed_ret_type = api.anal_type(method_type.ret_type) analyzed_ret_type = api.anal_type(get_proper_type(method_type.ret_type), allow_placeholder=True)
if analyzed_ret_type is None: assert analyzed_ret_type is not None
if isinstance(analyzed_ret_type, PlaceholderType):
unbound = True unbound = True
return SignatureTuple(arguments, analyzed_ret_type, unbound) return SignatureTuple(arguments, analyzed_ret_type, unbound)

View File

@@ -60,15 +60,11 @@ def resolve_passed_queryset_info_or_exception(ctx: DynamicClassDefContext) -> Ty
def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo: def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo:
api = sem_helpers.get_semanal_api(ctx) api = sem_helpers.get_semanal_api(ctx)
info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME)
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) if info is None:
if (sym is None
or sym.node is None
or isinstance(sym.node, PlaceholderNode)):
raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME) raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME)
assert isinstance(sym.node, TypeInfo) return info
return sym.node
def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo: def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo: