mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 13:04:47 +08:00
407 lines
15 KiB
Python
407 lines
15 KiB
Python
from abc import abstractmethod
|
|
from typing import (
|
|
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union,
|
|
cast)
|
|
|
|
from django.db.models.fields.related import RelatedField
|
|
from django.db.models.fields.reverse_related import ForeignObjectRel
|
|
from mypy.checker import TypeChecker
|
|
from mypy.mro import calculate_mro
|
|
from mypy.nodes import (
|
|
Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode,
|
|
TypeInfo, Var,
|
|
CallExpr, Context, PlaceholderNode, FuncDef, FakeInfo)
|
|
from mypy.plugin import DynamicClassDefContext, ClassDefContext
|
|
from mypy.plugins.common import add_method
|
|
from mypy.semanal import SemanticAnalyzer
|
|
from mypy.types import AnyType, Instance, NoneTyp, TypeType
|
|
from mypy.types import Type as MypyType
|
|
from mypy.types import TypeOfAny, UnionType
|
|
from mypy.typetraverser import TypeTraverserVisitor
|
|
|
|
from django.db.models.fields import Field
|
|
from mypy_django_plugin.lib import fullnames
|
|
from mypy_django_plugin.lib.sem_helpers import prepare_unannotated_method_signature, analyze_callable_signature
|
|
from mypy_django_plugin.transformers2 import new_helpers
|
|
|
|
if TYPE_CHECKING:
|
|
from mypy_django_plugin.django.context import DjangoContext
|
|
from mypy_django_plugin.main import NewSemanalDjangoPlugin
|
|
|
|
# Union of the two mypy API objects (type checker and semantic analyzer) that
# the module-level helpers in this file accept as their ``api`` argument.
AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer]
|
|
|
|
|
|
class DjangoPluginCallback:
    """Base class for plugin callbacks; keeps the plugin and its Django context."""

    # Copied from ``plugin.django_context`` when the callback is constructed.
    django_context: 'DjangoContext'

    def __init__(self, plugin: 'NewSemanalDjangoPlugin') -> None:
        """Remember ``plugin`` and expose its ``django_context`` on the callback."""
        self.django_context = plugin.django_context
        self.plugin = plugin

    # def lookup_fully_qualified(self, fullname: str) -> Optional[SymbolTableNode]:
    #     return self.plugin.lookup_fully_qualified(fullname)
|
|
|
|
|
|
class SemanalPluginCallback(DjangoPluginCallback):
    """Callback base for hooks that run during mypy's semantic analysis.

    Subclasses are expected to assign ``semanal_api`` before calling any of
    the helpers defined here; every helper reads it.
    """

    # Semantic analyzer the callback is currently running against.
    semanal_api: SemanticAnalyzer

    def build_defer_error_message(self, message: str) -> str:
        """Prefix ``message`` with the name of the concrete callback class."""
        return f'{self.__class__.__name__}: {message}'

    def defer_till_next_iteration(self, deferral_context: Optional[Context] = None,
                                  *,
                                  reason: Optional[str] = None) -> bool:
        """Ask the semantic analyzer to defer to its next iteration.

        Returns False if cannot be deferred (the analyzer reports that this is
        its final iteration); otherwise defers, prints a log line built from
        ``reason`` and returns True.
        """
        if self.semanal_api.final_iteration:
            return False
        self.semanal_api.defer(deferral_context)
        # NOTE(review): this log line is written with print(), i.e. to stdout.
        print(f'LOG: defer: {self.build_defer_error_message(reason)}')
        return True

    def lookup_typeinfo_or_defer(self, fullname: str, *,
                                 deferral_context: Optional[Context] = None,
                                 reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]:
        """Look up the TypeInfo registered under ``fullname``.

        When the symbol is missing, has no node, or is still a placeholder, the
        current pass is deferred and None is returned.

        Raises:
            new_helpers.TypeInfoNotFound: the symbol is unavailable and the
                pass cannot be deferred any more.
            ValueError: the symbol exists but its node is not a TypeInfo.
        """
        sym = self.plugin.lookup_fully_qualified(fullname)
        if sym is None or sym.node is None or isinstance(sym.node, PlaceholderNode):
            # Without an explicit context, defer on the current module node.
            deferral_context = deferral_context or self.semanal_api.cur_mod_node
            reason = reason_for_defer or f'{fullname!r} is not available for lookup'
            if not self.defer_till_next_iteration(deferral_context, reason=reason):
                raise new_helpers.TypeInfoNotFound(fullname)
            return None

        if not isinstance(sym.node, TypeInfo):
            raise ValueError(f'{fullname!r} does not correspond to TypeInfo')

        return sym.node

    def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo:
        """Create a TypeInfo called ``name`` with ``bases`` in the current module.

        The class definition has an empty body; its fullname comes from the
        analyzer's ``qualified_name``. MRO and metaclass type are computed
        from ``bases`` before the info is returned.
        """
        class_def = ClassDef(name, Block([]))
        class_def.fullname = self.semanal_api.qualified_name(name)

        info = TypeInfo(SymbolTable(), class_def, self.semanal_api.cur_mod_id)
        info.bases = bases
        calculate_mro(info)
        info.metaclass_type = info.calculate_metaclass_type()

        class_def.info = info
        return info

    # def add_symbol_table_node_or_defer(self, name: str, sym: SymbolTableNode) -> bool:
    #     return self.semanal_api.add_symbol_table_node(name, sym,
    #                                                   context=self.semanal_api.cur_mod_node)

    def add_method_from_signature(self,
                                  signature_node: FuncDef,
                                  new_method_name: str,
                                  new_self_type: Instance,
                                  class_defn: ClassDef) -> bool:
        """Add method ``new_method_name`` to ``class_defn``, copying a signature.

        The argument list and return type are taken from ``signature_node``;
        ``new_self_type`` is used as the type of ``self``.

        Returns:
            False if the pass was deferred because ``signature_node`` has no
            type yet; True once the method has been added.

        Raises:
            new_helpers.IncompleteDefnError: the analyzed signature still
                contains unbound types.
        """
        if signature_node.type is None:
            if self.defer_till_next_iteration(reason=signature_node.fullname):
                return False

            # Deferral is no longer possible: fall back to the signature
            # prepared for an unannotated method.
            arguments, return_type = prepare_unannotated_method_signature(signature_node)
            ctx = ClassDefContext(class_defn, signature_node, self.semanal_api)
            add_method(ctx,
                       new_method_name,
                       self_type=new_self_type,
                       args=arguments,
                       return_type=return_type)
            return True

        # add imported objects from method signature to the current module, if not present
        source_symbols = self.semanal_api.modules[signature_node.info.module_name].names
        currently_imported_symbols = self.semanal_api.cur_mod_node.names

        def import_symbol_from_source(name: str) -> None:
            # Names found in the source module's builtins need no import.
            if name in source_symbols['__builtins__'].node.names:
                return
            sym = source_symbols[name].copy()
            self.semanal_api.add_imported_symbol(name, sym, context=self.semanal_api.cur_mod_node)

        class UnimportedTypesVisitor(TypeTraverserVisitor):
            """Imports names used by the signature that the current module lacks."""

            def visit_union_type(self, t: UnionType) -> None:
                super().visit_union_type(t)
                union_sym = currently_imported_symbols.get('Union')
                if union_sym is None:
                    # TODO: check if it's exactly typing.Union
                    import_symbol_from_source('Union')

            def visit_type_type(self, t: TypeType) -> None:
                super().visit_type_type(t)
                # Look up 'Type' itself: checking 'Union' here would skip the
                # import whenever Union happens to be imported already.
                type_sym = currently_imported_symbols.get('Type')
                if type_sym is None:
                    # TODO: check if it's exactly typing.Type
                    import_symbol_from_source('Type')

            def visit_instance(self, t: Instance) -> None:
                super().visit_instance(t)
                if isinstance(t.type, FakeInfo):
                    return
                type_name = t.type.name
                sym = currently_imported_symbols.get(type_name)
                if sym is None:
                    # TODO: check that an existing symbol of this name refers to the same type
                    import_symbol_from_source(type_name)

        signature_node.type.accept(UnimportedTypesVisitor())

        # # copy global SymbolTableNode objects from original class to the current node, if not present
        # original_module = semanal_api.modules[method_node.info.module_name]
        # for name, sym in original_module.names.items():
        #     if (not sym.plugin_generated
        #             and name not in semanal_api.cur_mod_node.names):
        #         semanal_api.add_imported_symbol(name, sym, context=semanal_api.cur_mod_node)

        arguments, analyzed_return_type, unbound = analyze_callable_signature(self.semanal_api, signature_node)
        if unbound:
            raise new_helpers.IncompleteDefnError(f'Signature of method {signature_node.fullname!r} is not ready')

        # The analyzed argument list is one shorter than the node's arguments.
        assert len(arguments) + 1 == len(signature_node.arguments)
        assert analyzed_return_type is not None

        ctx = ClassDefContext(class_defn, signature_node, self.semanal_api)
        add_method(ctx,
                   new_method_name,
                   self_type=new_self_type,
                   args=arguments,
                   return_type=analyzed_return_type)
        return True
|
|
|
|
|
|
class DynamicClassPluginCallback(SemanalPluginCallback):
    """Callback base for hooks invoked with a ``DynamicClassDefContext``."""

    # Both are filled in from the hook context on every ``__call__``.
    class_name: str
    call_expr: CallExpr

    def __call__(self, ctx: DynamicClassDefContext) -> None:
        """Store the pieces of ``ctx`` on the callback and build the class."""
        self.semanal_api = cast(SemanticAnalyzer, ctx.api)
        self.call_expr = ctx.call
        self.class_name = ctx.name
        self.create_new_dynamic_class()

    def get_callee(self) -> MemberExpr:
        """Return the callee of the stored call, asserting it is a MemberExpr."""
        callee_expr = self.call_expr.callee
        assert isinstance(callee_expr, MemberExpr)
        return callee_expr

    def lookup_same_module_or_defer(self, name: str, *,
                                    deferral_context: Optional[Context] = None) -> Optional[SymbolTableNode]:
        """Resolve ``name`` through the analyzer's qualified lookup.

        Returns the symbol when it is available. Otherwise the pass is
        deferred and None is returned, or ``new_helpers.NameNotFound`` is
        raised when deferring is no longer possible.
        """
        sym = self.semanal_api.lookup_qualified(name, self.call_expr)
        is_resolved = (sym is not None
                       and sym.node is not None
                       and not isinstance(sym.node, PlaceholderNode))
        if is_resolved:
            return sym

        # Without an explicit context, defer on the call expression itself.
        context = deferral_context or self.call_expr
        message = f'{self.semanal_api.cur_mod_id}.{name} does not exist'
        if self.defer_till_next_iteration(context, reason=message):
            return None
        raise new_helpers.NameNotFound(name)

    @abstractmethod
    def create_new_dynamic_class(self) -> None:
        """Build the dynamic class; must be provided by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class ClassDefPluginCallback(SemanalPluginCallback):
    """Callback base for hooks invoked with a ``ClassDefContext``."""

    # Both are filled in from the hook context on every ``__call__``.
    reason: Expression
    class_defn: ClassDef

    def __call__(self, ctx: ClassDefContext) -> None:
        """Store the pieces of ``ctx`` on the callback and modify the class."""
        self.semanal_api = cast(SemanticAnalyzer, ctx.api)
        self.class_defn = ctx.cls
        self.reason = ctx.reason
        self.modify_class_defn()

    @abstractmethod
    def modify_class_defn(self) -> None:
        """Apply the class modification; must be provided by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
    """Return the ``'django'`` entry of ``model_info.metadata``.

    An empty dict is stored under that key first if the entry is missing, so
    the returned dict is always the one held by the metadata.
    """
    all_metadata = model_info.metadata
    django_metadata = all_metadata.setdefault('django', {})
    return django_metadata
|
|
|
|
|
|
def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]:
    """Split a dotted ``fullname`` into ``(module_name, symbol_name)``.

    The module is the longest dotted prefix of ``fullname`` that is a key of
    ``all_modules``; the symbol name is whatever follows it (an empty string
    when ``fullname`` itself names a module).

    Returns None when ``fullname`` contains no dot or none of its prefixes is
    a known module.
    """
    if '.' not in fullname:
        return None

    parts = fullname.split('.')
    # Try the longest candidate prefix first so the most specific module wins.
    for prefix_length in range(len(parts), 0, -1):
        module_name = '.'.join(parts[:prefix_length])
        if module_name in all_modules:
            # Drop only the leading module prefix. str.replace() would also
            # remove later occurrences of the module name inside the symbol.
            symbol_name = fullname[len(module_name):].lstrip('.')
            return module_name, symbol_name
    return None
|
|
|
|
|
|
def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]:
    """Find the TypeInfo named by ``fullname`` among the modules known to ``api``.

    Nested classes are supported: each dotted component after the module name
    is looked up in the symbol table of the previous one. Returns None when
    the module cannot be determined or any component is missing or is not a
    TypeInfo.
    """
    module_and_symbol = split_symbol_name(fullname, api.modules)
    if module_and_symbol is None:
        return None
    module_name, qualified_cls_name = module_and_symbol

    scope = api.modules[module_name].names  # type: Dict[str, SymbolTableNode]
    found = None  # type: Optional[TypeInfo]
    # Walk outer class -> inner class; every step has to land on a TypeInfo.
    for part in qualified_cls_name.split('.'):
        sym = scope.get(part)
        if sym is None or not isinstance(sym.node, TypeInfo):
            return None
        found = sym.node
        scope = found.names
    return found
|
|
|
|
|
|
def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]:
    """Look up the TypeInfo for the runtime class ``klass`` by its dotted name."""
    return lookup_fully_qualified_typeinfo(api, get_class_fullname(klass))
|
|
|
|
|
|
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
    """Build a new Instance of the same type with ``new_args`` as its arguments.

    Line and column are carried over from ``instance``.
    """
    source_info = instance.type
    return Instance(source_info,
                    args=new_args,
                    line=instance.line,
                    column=instance.column)
|
|
|
|
|
|
def get_class_fullname(klass: type) -> str:
    """Return ``klass``'s module and qualified name joined by a dot."""
    return '.'.join((klass.__module__, klass.__qualname__))
|
|
|
|
|
|
def make_optional(typ: MypyType) -> MypyType:
    """Return the union of ``typ`` and the None type."""
    members = [typ, NoneTyp()]
    return UnionType.make_union(members)
|
|
|
|
|
|
def parse_bool(expr: Expression) -> Optional[bool]:
    """Translate a ``True``/``False`` name expression into a Python bool.

    Returns None for every other expression, including name expressions whose
    fullname is neither ``builtins.True`` nor ``builtins.False``.
    """
    if not isinstance(expr, NameExpr):
        return None
    literal_values = {'builtins.True': True, 'builtins.False': False}
    return literal_values.get(expr.fullname)
|
|
|
|
|
|
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
    """Tell whether ``info.has_base`` holds for at least one name in ``bases``."""
    return any(info.has_base(candidate) for candidate in bases)
|
|
|
|
|
|
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
    """Yield every base of ``info`` depth-first, each followed by its own bases.

    A base reachable through several paths is yielded once per path.
    """
    exhausted = object()
    # Stack of iterators over ``bases`` lists, innermost on top.
    pending = [iter(info.bases)]
    while pending:
        base = next(pending[-1], exhausted)
        if base is exhausted:
            pending.pop()
            continue
        yield base
        pending.append(iter(base.type.bases))
|
|
|
|
|
|
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
    """Return the declared type of ``type_info``'s ``private_field_name``.

    Used for private Field attributes. The result is an explicit ``Any`` when
    the attribute is missing, is not a variable, or has no declared type.
    With ``is_nullable`` set, the declared type is made optional.
    """
    sym = type_info.get(private_field_name)
    node = None if sym is None else sym.node
    if not isinstance(node, Var) or node.type is None:
        return AnyType(TypeOfAny.explicit)

    declared_type = node.type
    return make_optional(declared_type) if is_nullable else declared_type
|
|
|
|
|
|
def get_field_lookup_exact_type(api: AnyPluginAPI, field: Field) -> MypyType:
    """Return the type used for an exact lookup on ``field``.

    Related fields and reverse relations give an optional instance of the
    related model, or a from-error ``Any`` if that model's TypeInfo is not
    found. Other fields give the declared type of ``_pyi_lookup_exact_type``
    on the field class (optional when ``field.null``), or an explicit ``Any``
    if the field class's TypeInfo is not found.
    """
    if not isinstance(field, (RelatedField, ForeignObjectRel)):
        field_info = lookup_class_typeinfo(api, field.__class__)
        if field_info is None:
            return AnyType(TypeOfAny.explicit)
        return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
                                           is_nullable=field.null)

    related_model_info = lookup_class_typeinfo(api, field.related_model)
    if related_model_info is None:
        return AnyType(TypeOfAny.from_error)
    return make_optional(Instance(related_model_info, []))
|
|
|
|
|
|
def get_current_module(api: AnyPluginAPI) -> MypyFile:
    """Return the module node that ``api`` is currently processing.

    A semantic analyzer exposes it as ``cur_mod_node``. For any other ``api``
    the scope stack is searched from the top for the first MypyFile, which is
    asserted to exist.
    """
    if isinstance(api, SemanticAnalyzer):
        return api.cur_mod_node

    innermost_file = next(
        (entry for entry in reversed(api.scope.stack) if isinstance(entry, MypyFile)),
        None,
    )
    assert innermost_file is not None
    return innermost_file
|
|
|
|
|
|
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
    """Replace ``Any`` inside ``typ`` with ``referred_to_type``.

    - a union is rebuilt from its converted items (recursively);
    - an instance gets each direct ``Any`` argument replaced (not recursive);
    - a bare ``Any`` becomes ``referred_to_type``;
    - every other type is returned unchanged.
    """
    if isinstance(typ, UnionType):
        converted_items = [convert_any_to_type(member, referred_to_type)
                           for member in typ.items]
        return UnionType.make_union(converted_items,
                                    line=typ.line, column=typ.column)

    if isinstance(typ, Instance):
        replaced_args = [referred_to_type if isinstance(arg, AnyType) else arg
                         for arg in typ.args]
        return reparametrize_instance(typ, replaced_args)

    return referred_to_type if isinstance(typ, AnyType) else typ
|
|
|
|
|
|
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
    """Resolve ``attr_expr`` to a value where that is possible.

    A string literal gives its value. A ``settings.NAME`` member expression,
    where ``settings`` is ``django.conf.settings``, gives that attribute of
    ``django_context.settings`` if it exists. Anything else gives None.
    """
    if isinstance(attr_expr, StrExpr):
        return attr_expr.value

    # support extracting from settings, in general case it's unresolvable yet
    if not isinstance(attr_expr, MemberExpr):
        return None

    receiver = attr_expr.expr
    refers_to_settings = (isinstance(receiver, NameExpr)
                          and receiver.fullname == 'django.conf.settings')
    if refers_to_settings and hasattr(django_context.settings, attr_expr.name):
        return getattr(django_context.settings, attr_expr.name)
    return None
|
|
|
|
|
|
def is_subclass_of_model(info: TypeInfo, django_context: 'DjangoContext') -> bool:
    """Tell whether ``info`` is a Django model class.

    True when its fullname is among the model fullnames registered in
    ``django_context``, or when it has the Model class as a base.
    """
    if info.fullname in django_context.all_registered_model_class_fullnames:
        return True
    return info.has_base(fullnames.MODEL_CLASS_FULLNAME)
|
|
|
|
|
|
def new_typeinfo(name: str,
                 *,
                 bases: List[Instance],
                 module_name: str) -> TypeInfo:
    """
    Construct new TypeInfo instance. Cannot be used for nested classes.

    The class gets an empty body and the fullname ``<module_name>.<name>``;
    MRO and metaclass type are computed from ``bases``.
    """
    defn = ClassDef(name, Block([]))
    defn.fullname = '.'.join((module_name, name))

    type_info = TypeInfo(SymbolTable(), defn, module_name)
    type_info.bases = bases
    calculate_mro(type_info)
    type_info.metaclass_type = type_info.calculate_metaclass_type()

    defn.info = type_info
    return type_info
|