diff --git a/jedi/evaluate/base_context.py b/jedi/evaluate/base_context.py index b2e80327..f81f15b9 100644 --- a/jedi/evaluate/base_context.py +++ b/jedi/evaluate/base_context.py @@ -17,23 +17,6 @@ from jedi.evaluate.utils import safe_property from jedi.evaluate.cache import evaluator_as_method_param_cache -def _is_same_class(class1, class2): - if class1 == class2: - return True - - try: - comp_func = class1.is_same_class - except AttributeError: - try: - comp_func = class2.is_same_class - except AttributeError: - return False - else: - return comp_func(class1) - else: - return comp_func(class2) - - class HelperContextMixin: @classmethod @evaluator_as_method_param_cache() @@ -63,10 +46,13 @@ class HelperContextMixin: def is_sub_class_of(self, class_context): from jedi.evaluate.context.klass import py__mro__ for cls in py__mro__(self): - if _is_same_class(cls, class_context): + if cls.is_same_class(class_context): return True return False + def is_same_class(self, class2): + return self == class2 + class Context(HelperContextMixin, BaseContext): """ @@ -304,12 +290,6 @@ class ContextSet(BaseContextSet): def get_item(self, *args, **kwargs): return ContextSet.from_sets(_getitem(c, *args, **kwargs) for c in self._set) - def is_sub_class_of(self, class_context): - for c in self._set: - if c.is_sub_class_of(class_context): - return True - return False - NO_CONTEXTS = ContextSet() diff --git a/jedi/evaluate/context/function.py b/jedi/evaluate/context/function.py index bfb485fc..dcb6e1c9 100644 --- a/jedi/evaluate/context/function.py +++ b/jedi/evaluate/context/function.py @@ -329,24 +329,18 @@ def signature_matches(function_context, arguments): return False if param_node.annotation is not None: - annotation_result = function_context.evaluator.eval_element( + annotation_contexts = function_context.evaluator.eval_element( function_context.parent_context, param_node.annotation ) - return any( - argument.infer().py__class__().is_sub_class_of(c) - for c in _type_vars_to_classes(annotation_result) - ) + argument_contexts = argument.infer().py__class__() + if not any(c1.is_sub_class_of(c2) + for c1 in argument_contexts + for c2 in annotation_contexts): + return False return True -def _type_vars_to_classes(context_set): - return ContextSet.from_sets( - context.get_classes() if isinstance(context, TypeVar) else set([context]) - for context in context_set - ) - - def has_same_class(context_set1, context_set2): for c1 in context_set1: for c2 in context_set2: diff --git a/jedi/evaluate/context/klass.py b/jedi/evaluate/context/klass.py index 53c0a897..278f5cad 100644 --- a/jedi/evaluate/context/klass.py +++ b/jedi/evaluate/context/klass.py @@ -243,6 +243,7 @@ class ClassContext(use_metaclass(CachedMetaClass, TreeContext)): def py__getitem__(self, index_context_set, contextualized_node): from jedi.evaluate.context.typing import TypingClassMixin, AnnotatedClass + #from pprint import pprint for cls in py__mro__(self): if isinstance(cls, TypingClassMixin): # TODO get the right classes. diff --git a/jedi/evaluate/context/typing.py b/jedi/evaluate/context/typing.py index 3cccb34c..44e237f3 100644 --- a/jedi/evaluate/context/typing.py +++ b/jedi/evaluate/context/typing.py @@ -41,16 +41,13 @@ class TypingName(AbstractTreeName): class _BaseTypingContext(Context): - def __init__(self, evaluator, name): - super(_BaseTypingContext, self).__init__( - evaluator, - parent_context=name.parent_context, - ) - self._name = name + def __init__(self, evaluator, parent_context, tree_name): + super(_BaseTypingContext, self).__init__(evaluator, parent_context) + self._tree_name = tree_name @property def tree_node(self): - return self._name.tree_name + return self._tree_name def get_filters(self, *args, **kwargs): # TODO this is obviously wrong. @@ -65,16 +62,13 @@ class _BaseTypingContext(Context): @property def name(self): - return ContextName(self, self._name.tree_name) + return ContextName(self, self.tree_name) def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, self._name.string_name) + return '%s(%s)' % (self.__class__.__name__, self._tree_name.value) class TypingModuleName(NameWrapper): - def __init__(self, *args, **kwargs): - assert not isinstance(args[0], TypingModuleName) - return super().__init__(*args, **kwargs) def infer(self): return ContextSet.from_iterable(self._remap()) @@ -86,28 +80,28 @@ class TypingModuleName(NameWrapper): except KeyError: pass else: - yield TypeAlias.create_cached(evaluator, self.tree_name, actual) + yield TypeAlias.create_cached(evaluator, self.parent_context, self.tree_name, actual) return if name in _PROXY_CLASS_TYPES: - yield TypingClassContext(evaluator, self) + yield TypingClassContext(evaluator, self.parent_context, self.tree_name) elif name in _PROXY_TYPES: - yield TypingContext.create_cached(evaluator, self) + yield TypingContext.create_cached(evaluator, self.parent_context, self.tree_name) elif name == 'runtime': # We don't want anything here, not sure what this function is # supposed to do, since it just appears in the stubs and shouldn't # have any effects there (because it's never executed). return elif name == 'TypeVar': - yield TypeVarClass.create_cached(evaluator, self) + yield TypeVarClass.create_cached(evaluator, self.parent_context, self.tree_name) elif name == 'Any': - yield Any.create_cached(evaluator, self) + yield Any.create_cached(evaluator, self.parent_context, self.tree_name) elif name == 'TYPE_CHECKING': # This is needed for e.g. imports that are only available for type # checking or are in cycles. The user can then check this variable. yield builtin_from_name(evaluator, u'True') elif name == 'overload': - yield OverloadFunction.create_cached(evaluator, self) + yield OverloadFunction.create_cached(evaluator, self.parent_context, self.tree_name) elif name == 'cast': # TODO implement cast for c in self._wrapped_name.infer(): # Fuck my life Python 2 @@ -131,15 +125,15 @@ class TypingModuleFilterWrapper(FilterWrapper): class _WithIndexBase(_BaseTypingContext): - def __init__(self, evaluator, name, index_context, context_of_index): - super(_WithIndexBase, self).__init__(evaluator, name) + def __init__(self, evaluator, parent_context, name, index_context, context_of_index): + super(_WithIndexBase, self).__init__(evaluator, parent_context, name) self._index_context = index_context self._context_of_index = context_of_index def __repr__(self): return '<%s: %s[%s]>' % ( self.__class__.__name__, - self._name.string_name, + self._tree_name.value, self._index_context, ) @@ -151,7 +145,7 @@ class _WithIndexBase(_BaseTypingContext): class TypingContextWithIndex(_WithIndexBase): def execute_annotation(self): - string_name = self._name.string_name + string_name = self._tree_name.value if string_name == 'Union': # This is kind of a special case, because we have Unions (in Jedi @@ -170,7 +164,13 @@ class TypingContextWithIndex(_WithIndexBase): return self._index_context.execute_annotation() cls = globals()[string_name] - return ContextSet(cls(self._name, self._index_context, self._context_of_index)) + return ContextSet(cls( + self.evaluator, + self.parent_context, + self._tree_name, + self._index_context, + self._context_of_index + )) class TypingContext(_BaseTypingContext): @@ -181,7 +181,8 @@ class TypingContext(_BaseTypingContext): return ContextSet.from_iterable( self.index_class.create_cached( self.evaluator, - self._name, + self.parent_context, + self._tree_name, index_context, context_of_index=contextualized_node.context) for index_context in index_context_set @@ -225,8 +226,9 @@ def _iter_over_arguments(maybe_tuple_context, defining_context): class TypeAlias(object): - def __init__(self, evaluator, origin_tree_name, actual): + def __init__(self, evaluator, parent_context, origin_tree_name, actual): self.evaluator = evaluator + self.parent_context = parent_context self._origin_tree_name = origin_tree_name self._actual = actual # e.g. builtins.list @@ -337,7 +339,13 @@ class TypeVarClass(_BaseTypingContext): debug.warning('Found a variable without a name %s', arguments) return NO_CONTEXTS - return ContextSet(TypeVar.create_cached(self.evaluator, self._name, var_name, unpacked)) + return ContextSet(TypeVar.create_cached( + self.evaluator, + self.parent_context, + self._tree_name, + var_name, + unpacked + )) def _find_string_name(self, lazy_context): if lazy_context is None: @@ -356,8 +364,8 @@ class TypeVarClass(_BaseTypingContext): class TypeVar(_BaseTypingContext): - def __init__(self, evaluator, class_name, var_name, unpacked_args): - super(TypeVar, self).__init__(evaluator, class_name) + def __init__(self, evaluator, parent_context, tree_name, var_name, unpacked_args): + super(TypeVar, self).__init__(evaluator, parent_context, tree_name) self.var_name = var_name self._constraints_lazy_contexts = [] @@ -380,7 +388,7 @@ class TypeVar(_BaseTypingContext): def get_filters(self, *args, **kwargs): return iter([]) - def get_classes(self): + def _get_classes(self): if self._bound_lazy_context is not None: return self._bound_lazy_context.infer() if self._constraints_lazy_contexts: @@ -397,7 +405,7 @@ class TypeVar(_BaseTypingContext): ) def execute_annotation(self): - return self.get_classes().execute_annotation() + return self._get_classes().execute_annotation() def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.var_name) @@ -508,6 +516,30 @@ class _AbstractAnnotatedClass(ClassContext): if subscriptlist.type != 'subscript': yield subscriptlist + def is_same_class(self, other): + if not isinstance(other, _AbstractAnnotatedClass): + return False + + if self.tree_node != other.tree_node: + # TODO not sure if this is nice. + return False + + given_params1 = self.get_given_types() + given_params2 = other.get_given_types() + if len(given_params1) != len(given_params2): + # If the amount of type vars doesn't match, the class doesn't + # match. + return False + + # Now compare generics + return all( + any( + cls1.is_same_class(cls2) + for cls1 in class_set1 + for cls2 in class_set2 + ) for class_set1, class_set2 in zip(given_params1, given_params2) + ) + def get_given_types(self): raise NotImplementedError @@ -534,9 +566,6 @@ class AnnotatedClass(_AbstractAnnotatedClass): def get_given_types(self): return list(_iter_over_arguments(self._index_context, self._context_of_index)) - def is_same_class(self, other): - return self == other - class AnnotatedSubClass(_AbstractAnnotatedClass): def __init__(self, evaluator, parent_context, tree_node, given_types): diff --git a/jedi/plugins/typeshed.py b/jedi/plugins/typeshed.py index 8e5b3e8c..ee4167fe 100644 --- a/jedi/plugins/typeshed.py +++ b/jedi/plugins/typeshed.py @@ -287,6 +287,8 @@ class StubParserTreeFilter(ParserTreeFilter): # for all API accesses. Otherwise the user will be directed to the # non-stub positions (see NameWithStub). n = TreeNameDefinition(self.context, name) + if isinstance(self.context, TypingModuleWrapper): + n = TypingModuleName(n) if len(non_stub_names): for non_stub_name in non_stub_names: if isinstance(non_stub_name, CompiledName): @@ -446,7 +448,8 @@ class CompiledStubFunctionContext(_StubContextWithCompiled): class TypingModuleWrapper(StubOnlyModuleContext): - def get_filters(self, *args, **kwargs): + # TODO should use this instead of the isinstance check + def get_filterss(self, *args, **kwargs): filters = super(TypingModuleWrapper, self).get_filters(*args, **kwargs) yield TypingModuleFilterWrapper(next(filters)) for f in filters: