diff --git a/jedi/inference/gradual/base.py b/jedi/inference/gradual/base.py index 88363b6d..6ea65342 100644 --- a/jedi/inference/gradual/base.py +++ b/jedi/inference/gradual/base.py @@ -142,39 +142,7 @@ class AnnotatedClassContext(ClassContext): yield self._value.get_type_var_filter() -class AbstractAnnotatedClass(ClassMixin, ValueWrapper): - def get_type_var_filter(self): - return TypeVarFilter(self.get_generics(), self.list_type_vars()) - - def is_same_class(self, other): - if not isinstance(other, AbstractAnnotatedClass): - return False - - if self.tree_node != other.tree_node: - # TODO not sure if this is nice. - return False - given_params1 = self.get_generics() - given_params2 = other.get_generics() - - if len(given_params1) != len(given_params2): - # If the amount of type vars doesn't match, the class doesn't - # match. - return False - - # Now compare generics - return all( - any( - # TODO why is this ordering the correct one? - cls2.is_same_class(cls1) - for cls1 in class_set1 - for cls2 in class_set2 - ) for class_set1, class_set2 in zip(given_params1, given_params2) - ) - - def py__call__(self, arguments): - instance, = super(AbstractAnnotatedClass, self).py__call__(arguments) - return ValueSet([InstanceWrapper(instance)]) - +class DefineGenericBase(ValueWrapper): def get_generics(self): raise NotImplementedError @@ -205,8 +173,30 @@ class AbstractAnnotatedClass(ClassMixin, ValueWrapper): generics=tuple(new_generics) )]) - def _as_context(self): - return AnnotatedClassContext(self) + def is_same_class(self, other): + if not isinstance(other, DefineGenericBase): + return False + + if self.tree_node != other.tree_node: + # TODO not sure if this is nice. + return False + given_params1 = self.get_generics() + given_params2 = other.get_generics() + + if len(given_params1) != len(given_params2): + # If the amount of type vars doesn't match, the class doesn't + # match. + return False + + # Now compare generics + return all( + any( + # TODO why is this ordering the correct one? + cls2.is_same_class(cls1) + for cls1 in class_set1 + for cls2 in class_set2 + ) for class_set1, class_set2 in zip(given_params1, given_params2) + ) def __repr__(self): return '<%s: %s%s>' % ( @@ -215,6 +205,18 @@ class AbstractAnnotatedClass(ClassMixin, ValueWrapper): list(self.get_generics()), ) + +class AbstractAnnotatedClass(ClassMixin, DefineGenericBase): + def get_type_var_filter(self): + return TypeVarFilter(self.get_generics(), self.list_type_vars()) + + def py__call__(self, arguments): + instance, = super(AbstractAnnotatedClass, self).py__call__(arguments) + return ValueSet([InstanceWrapper(instance)]) + + def _as_context(self): + return AnnotatedClassContext(self) + @to_list def py__bases__(self): for base in self._wrapped_value.py__bases__():