mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 13:04:47 +08:00
219 lines
7.2 KiB
Python
219 lines
7.2 KiB
Python
from typing import (
|
|
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union,
|
|
)
|
|
|
|
from django.db.models.fields import Field
|
|
from django.db.models.fields.related import RelatedField
|
|
from django.db.models.fields.reverse_related import ForeignObjectRel
|
|
from mypy.checker import TypeChecker
|
|
from mypy.mro import calculate_mro
|
|
from mypy.nodes import (
|
|
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTableNode,
|
|
TypeInfo, Var,
|
|
)
|
|
from mypy.semanal import SemanticAnalyzer
|
|
from mypy.types import AnyType, Instance, NoneTyp
|
|
from mypy.types import Type as MypyType
|
|
from mypy.types import TypeOfAny, UnionType
|
|
|
|
from mypy_django_plugin.lib import fullnames
|
|
|
|
if TYPE_CHECKING:
|
|
from mypy_django_plugin.django.context import DjangoContext
|
|
|
|
# Alias for the two mypy API objects the helpers in this module accept:
# the SemanticAnalyzer or the TypeChecker. Helpers that need behaviour
# specific to one of them narrow with isinstance (see get_current_module).
AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer]
|
|
|
|
|
|
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
    """Return the plugin's ``'django'`` metadata dict stored on ``model_info``.

    The dict lives in ``model_info.metadata`` and is created empty on first
    access, so callers can mutate the returned object in place.
    """
    all_metadata = model_info.metadata
    if 'django' not in all_metadata:
        all_metadata['django'] = {}
    return all_metadata['django']
|
|
|
|
|
|
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
    """Resolve a dotted ``fullname`` to its symbol table node.

    The longest dotted prefix of ``fullname`` that is a key of ``all_modules``
    is taken as the module. The remaining components are looked up in that
    module's symbol table; when there is more than one, the leading ones must
    name classes (``TypeInfo`` nodes), and the lookup descends into their
    symbol tables to reach nested classes.

    :param fullname: dotted name such as ``'pkg.module.Outer.Inner'``.
    :param all_modules: mapping of module name to parsed ``MypyFile``.
    :returns: the matching ``SymbolTableNode``, or ``None`` when ``fullname``
        contains no dot, no prefix names a known module, ``fullname`` names a
        module itself, or any component along the path cannot be resolved.
    """
    if '.' not in fullname:
        return None

    parts = fullname.split('.')
    module_file = None
    # Try the longest prefix first so 'a.b.C' prefers module 'a.b' over 'a'.
    # parts has at least two elements here, so the loop body always runs.
    for split_at in range(len(parts), 0, -1):
        possible_module_name = '.'.join(parts[:split_at])
        if possible_module_name in all_modules:
            module_file = all_modules[possible_module_name]
            break

    if module_file is None:
        return None

    # Drop the module prefix by component count. The previous
    # str.replace-based approach also removed later repetitions of the
    # module name inside the remaining path (e.g. 'a.b.a.b.C').
    attr_path = parts[split_at:]
    if not attr_path:
        # fullname names a module, not a symbol inside one.
        return None

    *parent_names, target_name = attr_path
    sym_table = module_file.names
    # Nested class: walk through each enclosing class's symbol table.
    for parent_name in parent_names:
        parent_sym = sym_table.get(parent_name)
        if parent_sym is None or not isinstance(parent_sym.node, TypeInfo):
            return None
        sym_table = parent_sym.node.names

    return sym_table.get(target_name)
|
|
|
|
|
|
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
    """Return the node behind the symbol ``name``, or ``None`` if the symbol is not found.

    Thin wrapper over ``lookup_fully_qualified_sym`` that unwraps the
    ``SymbolTableNode`` to its ``.node``.
    """
    table_entry = lookup_fully_qualified_sym(name, all_modules)
    return None if table_entry is None else table_entry.node
|
|
|
|
|
|
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
    """Return the ``TypeInfo`` named ``fullname`` among ``api.modules``.

    Yields ``None`` when the name does not resolve or resolves to something
    other than a class (``TypeInfo``).
    """
    resolved = lookup_fully_qualified_generic(fullname, api.modules)
    return resolved if isinstance(resolved, TypeInfo) else None
|
|
|
|
|
|
def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
    """Return the mypy ``TypeInfo`` corresponding to the runtime class ``klass``.

    The class is located by its ``module.qualname`` full name; ``None`` is
    returned when mypy has no matching class.
    """
    return lookup_fully_qualified_typeinfo(api, get_class_fullname(klass))
|
|
|
|
|
|
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
    """Build a new ``Instance`` of the same class with ``new_args`` as type arguments.

    Only the ``TypeInfo`` and the source position (line/column) are copied
    from ``instance``; ``instance`` itself is left untouched.
    """
    type_info = instance.type
    return Instance(type_info, args=new_args, line=instance.line, column=instance.column)
|
|
|
|
|
|
def get_class_fullname(klass: type) -> str:
    """Return ``klass``'s dotted full name: ``<module>.<qualname>``.

    ``__qualname__`` is used, so nested classes keep their enclosing class
    names (e.g. ``'pkg.mod.Outer.Inner'``).
    """
    module_name = klass.__module__
    qualified_name = klass.__qualname__
    return '.'.join((module_name, qualified_name))
|
|
|
|
|
|
def make_optional(typ: MypyType) -> MypyType:
    """Return ``typ`` unioned with ``None``.

    The result is whatever ``UnionType.make_union`` produces for the two
    items ``typ`` and a fresh ``NoneTyp()``.
    """
    none_member = NoneTyp()
    union_members = [typ, none_member]
    return UnionType.make_union(union_members)
|
|
|
|
|
|
def parse_bool(expr: Expression) -> Optional[bool]:
    """Statically evaluate ``expr`` as a boolean literal.

    Returns ``True``/``False`` when ``expr`` is a name referring to
    ``builtins.True``/``builtins.False``, and ``None`` for any other expression.
    """
    if not isinstance(expr, NameExpr):
        return None
    known_literals = {'builtins.True': True, 'builtins.False': False}
    return known_literals.get(expr.fullname)
|
|
|
|
|
|
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
    """Return whether ``info`` has at least one of the given base full names.

    Checks stop at the first match, as reported by ``TypeInfo.has_base``.
    """
    return any(info.has_base(candidate) for candidate in bases)
|
|
|
|
|
|
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
    """Lazily yield every base of ``info``, depth-first in declaration order.

    Each direct base is yielded, immediately followed by all of its own bases
    (recursively), before moving on to the next direct base. A class reached
    through several paths is yielded once per path.
    """
    for direct_base in info.bases:
        yield direct_base
        for ancestor in iter_bases(direct_base.type):
            yield ancestor
|
|
|
|
|
|
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
    """ Return declared type of type_info's private_field_name (used for private Field attributes).

    Falls back to an explicit ``Any`` when the attribute is missing, is not a
    variable (``Var``), or has no declared type. When ``is_nullable`` is set,
    the declared type is wrapped with ``make_optional``.
    """
    symbol = type_info.get(private_field_name)
    var_node = symbol.node if symbol is not None else None
    if not isinstance(var_node, Var) or var_node.type is None:
        return AnyType(TypeOfAny.explicit)

    declared = var_node.type
    return make_optional(declared) if is_nullable else declared
|
|
|
|
|
|
def get_field_lookup_exact_type(api: AnyPluginAPI, field: Field) -> MypyType:
    """Return the value type accepted by an exact lookup on ``field``.

    * Relation fields (``RelatedField`` / ``ForeignObjectRel``): an optional
      instance of the related model, or ``Any`` (``from_error``) when the
      related model's ``TypeInfo`` cannot be found.
    * Other fields: the declared type of ``_pyi_lookup_exact_type`` on the
      field class, made optional when ``field.null`` is set, or an explicit
      ``Any`` when the field class's ``TypeInfo`` cannot be found.
    """
    if not isinstance(field, (RelatedField, ForeignObjectRel)):
        field_class_info = lookup_class_typeinfo(api, field.__class__)
        if field_class_info is None:
            return AnyType(TypeOfAny.explicit)
        return get_private_descriptor_type(field_class_info, '_pyi_lookup_exact_type',
                                           is_nullable=field.null)

    related_info = lookup_class_typeinfo(api, field.related_model)
    if related_info is None:
        return AnyType(TypeOfAny.from_error)
    return make_optional(Instance(related_info, []))
|
|
|
|
|
|
def get_current_module(api: AnyPluginAPI) -> MypyFile:
    """Return the ``MypyFile`` currently being processed by ``api``.

    For the semantic analyzer this is ``cur_mod_node``. For the type checker
    it is the innermost (last) ``MypyFile`` on ``api.scope.stack``; an
    ``AssertionError`` is raised if the stack holds none.
    """
    if isinstance(api, SemanticAnalyzer):
        return api.cur_mod_node

    innermost_file = next(
        (entry for entry in reversed(api.scope.stack) if isinstance(entry, MypyFile)),
        None,
    )
    assert innermost_file is not None
    return innermost_file
|
|
|
|
|
|
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
    """Substitute ``referred_to_type`` for ``Any`` in ``typ``.

    * ``Union``: every member is converted recursively and the union is
      rebuilt at the same line/column.
    * ``Instance``: each type argument that is ``Any`` is replaced; other
      arguments are kept as-is (no recursion into them).
    * ``Any`` itself becomes ``referred_to_type``.
    * Any other type is returned unchanged.
    """
    if isinstance(typ, UnionType):
        new_items = [convert_any_to_type(member, referred_to_type) for member in typ.items]
        return UnionType.make_union(new_items, line=typ.line, column=typ.column)

    if isinstance(typ, Instance):
        new_args = [referred_to_type if isinstance(arg, AnyType) else arg
                    for arg in typ.args]
        return reparametrize_instance(typ, new_args)

    return referred_to_type if isinstance(typ, AnyType) else typ
|
|
|
|
|
|
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
    """Statically resolve ``attr_expr`` to a string, when possible.

    Two forms are supported:

    * a string literal, whose value is returned directly;
    * ``<name>.<SETTING>`` where ``<name>`` resolves to
      ``django.conf.settings``; the value is read from
      ``django_context.settings``.

    Returns ``None`` for any other expression, and for a settings attribute
    that is missing or whose value is not a ``str`` (such a value cannot
    satisfy the ``Optional[str]`` contract callers rely on).
    """
    if isinstance(attr_expr, StrExpr):
        return attr_expr.value

    # support extracting from settings, in general case it's unresolvable yet
    if (isinstance(attr_expr, MemberExpr)
            and isinstance(attr_expr.expr, NameExpr)
            and attr_expr.expr.fullname == 'django.conf.settings'):
        # Single lookup; a missing setting yields None, like the old hasattr check.
        setting_value = getattr(django_context.settings, attr_expr.name, None)
        # Previously any settings value was returned verbatim, even non-strings.
        if isinstance(setting_value, str):
            return setting_value
    return None
|
|
|
|
|
|
def is_subclass_of_model(info: TypeInfo, django_context: 'DjangoContext') -> bool:
    """Return whether ``info`` describes a Django model class.

    True if its full name is among the registered model classes in
    ``django_context``, or if it has ``fullnames.MODEL_CLASS_FULLNAME`` as a base.
    """
    if info.fullname in django_context.all_registered_model_class_fullnames:
        return True
    return info.has_base(fullnames.MODEL_CLASS_FULLNAME)
|
|
|
|
|
|
def new_typeinfo(name: str,
                 *,
                 bases: List[Instance],
                 module_name: str) -> TypeInfo:
    """
    Construct new TypeInfo instance. Cannot be used for nested classes.

    The class gets an empty body and an empty symbol table, full name
    ``<module_name>.<name>`` (hence no nesting support), the given ``bases``,
    a computed MRO and metaclass type, and its ``ClassDef`` is linked back
    to the returned ``TypeInfo``.
    """
    definition = ClassDef(name, Block([]))
    definition.fullname = '.'.join((module_name, name))

    type_info = TypeInfo(SymbolTable(), definition, module_name)
    type_info.bases = bases
    # The MRO must be computed from the bases before the metaclass can be derived.
    calculate_mro(type_info)
    type_info.metaclass_type = type_info.calculate_metaclass_type()

    definition.info = type_info
    return type_info
|