fix tests

This commit is contained in:
Maxim Kurnikov
2020-01-05 08:18:43 +03:00
parent 0a92c89d41
commit e5b61dc499
3 changed files with 33 additions and 35 deletions

View File

@@ -1,5 +1,5 @@
from typing import (
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union,
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union,
)
from django.db.models.fields import Field
@@ -8,7 +8,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import (
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTableNode,
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode,
TypeInfo, Var,
)
from mypy.semanal import SemanticAnalyzer
@@ -28,23 +28,32 @@ def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]:
if '.' not in fullname:
return None
module_file = None
module_name = None
parts = fullname.split('.')
for i in range(len(parts), 0, -1):
possible_module_name = '.'.join(parts[:i])
if possible_module_name in all_modules:
module_file = all_modules[possible_module_name]
module_name = possible_module_name
break
if module_file is None:
if module_name is None:
return None
cls_name = fullname.replace(module_file.fullname, '').lstrip('.')
sym_table = module_file.names
cls_name = fullname.replace(module_name, '').lstrip('.')
return module_name, cls_name
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
split = split_symbol_name(fullname, api.modules)
if split is None:
return None
module_name, cls_name = split
sym_table = api.modules[module_name].names # type: Dict[str, SymbolTableNode]
if '.' in cls_name:
parent_cls_name, _, cls_name = cls_name.rpartition('.')
# nested class
@@ -55,23 +64,14 @@ def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile])
return None
sym_table = sym.node.names
return sym_table.get(cls_name)
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
sym = lookup_fully_qualified_sym(name, all_modules)
if sym is None:
sym = sym_table.get(cls_name)
if (sym is None
or sym.node is None
or not isinstance(sym.node, TypeInfo)):
return None
return sym.node
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo):
return None
return node
def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname)