diff --git a/jedi/evaluate/base_context.py b/jedi/evaluate/base_context.py index 3d1c441c..b2e80327 100644 --- a/jedi/evaluate/base_context.py +++ b/jedi/evaluate/base_context.py @@ -14,9 +14,32 @@ from jedi.parser_utils import clean_scope_docstring, get_doc_with_call_signature from jedi.common import BaseContextSet, BaseContext from jedi.evaluate.helpers import SimpleGetItemNotFound, execute_evaluated from jedi.evaluate.utils import safe_property +from jedi.evaluate.cache import evaluator_as_method_param_cache + + +def _is_same_class(class1, class2): + if class1 == class2: + return True + + try: + comp_func = class1.is_same_class + except AttributeError: + try: + comp_func = class2.is_same_class + except AttributeError: + return False + else: + return comp_func(class1) + else: + return comp_func(class2) class HelperContextMixin: + @classmethod + @evaluator_as_method_param_cache() + def create_cached(cls, *args, **kwargs): + return cls(*args, **kwargs) + def execute_evaluated(self, *value_list): return execute_evaluated(self, *value_list) @@ -37,6 +60,13 @@ class HelperContextMixin: return f.filter_name(filters) return f.find(filters, attribute_lookup=not search_global) + def is_sub_class_of(self, class_context): + from jedi.evaluate.context.klass import py__mro__ + for cls in py__mro__(self): + if _is_same_class(cls, class_context): + return True + return False + class Context(HelperContextMixin, BaseContext): """ @@ -274,6 +304,12 @@ class ContextSet(BaseContextSet): def get_item(self, *args, **kwargs): return ContextSet.from_sets(_getitem(c, *args, **kwargs) for c in self._set) + def is_sub_class_of(self, class_context): + for c in self._set: + if c.is_sub_class_of(class_context): + return True + return False + NO_CONTEXTS = ContextSet() diff --git a/jedi/evaluate/context/function.py b/jedi/evaluate/context/function.py index bdfb1781..bfb485fc 100644 --- a/jedi/evaluate/context/function.py +++ b/jedi/evaluate/context/function.py @@ -333,11 +333,10 @@ def signature_matches(function_context, arguments): function_context.parent_context, param_node.annotation ) - return has_same_class( - argument.infer().py__class__(), - _type_vars_to_classes(annotation_result), + return any( + argument.infer().py__class__().is_sub_class_of(c) + for c in _type_vars_to_classes(annotation_result) ) - return True diff --git a/jedi/evaluate/context/typing.py b/jedi/evaluate/context/typing.py index fff5bb3a..3cccb34c 100644 --- a/jedi/evaluate/context/typing.py +++ b/jedi/evaluate/context/typing.py @@ -41,9 +41,9 @@ class TypingName(AbstractTreeName): class _BaseTypingContext(Context): - def __init__(self, name): + def __init__(self, evaluator, name): super(_BaseTypingContext, self).__init__( - name.parent_context.evaluator, + evaluator, parent_context=name.parent_context, ) self._name = name @@ -72,6 +72,9 @@ class _BaseTypingContext(Context): class TypingModuleName(NameWrapper): + def __init__(self, *args, **kwargs): + assert not isinstance(args[0], TypingModuleName) + return super().__init__(*args, **kwargs) def infer(self): return ContextSet.from_iterable(self._remap()) @@ -83,28 +86,28 @@ class TypingModuleName(NameWrapper): except KeyError: pass else: - yield TypeAlias(evaluator, self.tree_name, actual) + yield TypeAlias.create_cached(evaluator, self.tree_name, actual) return if name in _PROXY_CLASS_TYPES: - yield TypingClassContext(self) + yield TypingClassContext(evaluator, self) elif name in _PROXY_TYPES: - yield TypingContext(self) + yield TypingContext.create_cached(evaluator, self) elif name == 'runtime': # We don't want anything here, not sure what this function is # supposed to do, since it just appears in the stubs and shouldn't # have any effects there (because it's never executed). return elif name == 'TypeVar': - yield TypeVarClass(self) + yield TypeVarClass.create_cached(evaluator, self) elif name == 'Any': - yield Any(self) + yield Any.create_cached(evaluator, self) elif name == 'TYPE_CHECKING': # This is needed for e.g. imports that are only available for type # checking or are in cycles. The user can then check this variable. yield builtin_from_name(evaluator, u'True') elif name == 'overload': - yield OverloadFunction(self) + yield OverloadFunction.create_cached(evaluator, self) elif name == 'cast': # TODO implement cast for c in self._wrapped_name.infer(): # Fuck my life Python 2 @@ -128,8 +131,8 @@ class TypingModuleFilterWrapper(FilterWrapper): class _WithIndexBase(_BaseTypingContext): - def __init__(self, name, index_context, context_of_index): - super(_WithIndexBase, self).__init__(name) + def __init__(self, evaluator, name, index_context, context_of_index): + super(_WithIndexBase, self).__init__(evaluator, name) self._index_context = index_context self._context_of_index = context_of_index @@ -176,7 +179,8 @@ class TypingContext(_BaseTypingContext): def py__getitem__(self, index_context_set, contextualized_node): return ContextSet.from_iterable( - self.index_class( + self.index_class.create_cached( + self.evaluator, self._name, index_context, context_of_index=contextualized_node.context) @@ -322,17 +326,6 @@ class Any(_BaseTypingContext): return NO_CONTEXTS -class GenericClass(object): - def __init__(self, class_context, ): - self._class_context = class_context - - def __getattr__(self, name): - return getattr(self._class_context, name) - - def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, self._class_context) - - class TypeVarClass(_BaseTypingContext): def py__call__(self, arguments): unpacked = arguments.unpack() @@ -344,7 +337,7 @@ class TypeVarClass(_BaseTypingContext): debug.warning('Found a variable without a name %s', arguments) return NO_CONTEXTS - return ContextSet(TypeVar(self._name, var_name, unpacked)) + return ContextSet(TypeVar.create_cached(self.evaluator, self._name, var_name, unpacked)) def _find_string_name(self, lazy_context): if lazy_context is None: @@ -363,8 +356,8 @@ class TypeVarClass(_BaseTypingContext): class TypeVar(_BaseTypingContext): - def __init__(self, class_name, var_name, unpacked_args): - super(TypeVar, self).__init__(class_name) + def __init__(self, evaluator, class_name, var_name, unpacked_args): + super(TypeVar, self).__init__(evaluator, class_name) self.var_name = var_name self._constraints_lazy_contexts = [] @@ -541,6 +534,9 @@ class AnnotatedClass(_AbstractAnnotatedClass): def get_given_types(self): return list(_iter_over_arguments(self._index_context, self._context_of_index)) + def is_same_class(self, other): + return self == other + class AnnotatedSubClass(_AbstractAnnotatedClass): def __init__(self, evaluator, parent_context, tree_node, given_types): @@ -561,7 +557,7 @@ class LazyAnnotatedBaseClass(object): for base in self._lazy_base_class.infer(): if isinstance(base, _AbstractAnnotatedClass): # Here we have to recalculate the given types. - yield AnnotatedSubClass( + yield AnnotatedSubClass.create_cached( base.evaluator, base.parent_context, base.tree_node, diff --git a/jedi/plugins/typeshed.py b/jedi/plugins/typeshed.py index 53511d64..8e5b3e8c 100644 --- a/jedi/plugins/typeshed.py +++ b/jedi/plugins/typeshed.py @@ -287,8 +287,6 @@ class StubParserTreeFilter(ParserTreeFilter): # for all API accesses. Otherwise the user will be directed to the # non-stub positions (see NameWithStub). n = TreeNameDefinition(self.context, name) - if isinstance(self.context, TypingModuleWrapper): - n = TypingModuleName(n) if len(non_stub_names): for non_stub_name in non_stub_names: if isinstance(non_stub_name, CompiledName):