Refactor dict/set/list/tuple literal generic inferring

This commit is contained in:
Dave Halter
2018-09-27 00:01:35 +02:00
parent b5b0214c3c
commit 8e8271cf54
4 changed files with 28 additions and 10 deletions

View File

@@ -62,6 +62,9 @@ class HelperContextMixin:
return False return False
def is_same_class(self, class2): def is_same_class(self, class2):
# Class matching should prefer comparisons that are not this function.
if type(class2).is_same_class != HelperContextMixin.is_same_class:
return class2.is_same_class(self)
return self == class2 return self == class2

View File

@@ -292,8 +292,8 @@ class FunctionExecutionContext(TreeContext):
if debug.enable_notice: if debug.enable_notice:
signature = parser_utils.get_call_signature(self.tree_node) signature = parser_utils.get_call_signature(self.tree_node)
if matches: if matches:
debug.dbg("Overloading match: %s@%s", debug.dbg("Overloading match: %s@%s (%s)",
signature, self.tree_node.start_pos[0], color='BLUE') signature, self.tree_node.start_pos[0], self.var_args, color='BLUE')
else: else:
debug.dbg("Overloading no match: %s@%s (%s)", debug.dbg("Overloading no match: %s@%s (%s)",
signature, self.tree_node.start_pos[0], self.var_args, color='BLUE') signature, self.tree_node.start_pos[0], self.var_args, color='BLUE')

View File

@@ -191,6 +191,11 @@ class ComprehensionMixin(object):
return "<%s of %s>" % (type(self).__name__, self._atom) return "<%s of %s>" % (type(self).__name__, self._atom)
class _DictMixin(object):
def _get_generics(self):
return tuple(c_set.py__class__() for c_set in self.get_mapping_item_contexts())
class Sequence(BuiltinOverwrite, IterableMixin): class Sequence(BuiltinOverwrite, IterableMixin):
api_type = u'instance' api_type = u'instance'
@@ -198,11 +203,17 @@ class Sequence(BuiltinOverwrite, IterableMixin):
def name(self): def name(self):
return compiled.CompiledContextName(self, self.array_type) return compiled.CompiledContextName(self, self.array_type)
def _get_generics(self):
return (self.merge_types_of_iterate().py__class__(),)
@memoize_method @memoize_method
def get_object(self): def get_object(self):
from jedi.evaluate.context.typing import AnnotatedSubClass
klass = compiled.builtin_from_name(self.evaluator, self.array_type) klass = compiled.builtin_from_name(self.evaluator, self.array_type)
instance, = klass.execute_evaluated(self) return AnnotatedSubClass(
return instance self.evaluator, klass.parent_context, klass.tree_node,
self._get_generics()
).execute_annotation()
def py__bool__(self): def py__bool__(self):
return None # We don't know the length, because of appends. return None # We don't know the length, because of appends.
@@ -237,7 +248,7 @@ class SetComprehension(ComprehensionMixin, Sequence):
array_type = u'set' array_type = u'set'
class DictComprehension(ComprehensionMixin, Sequence): class DictComprehension(_DictMixin, ComprehensionMixin, Sequence):
array_type = u'dict' array_type = u'dict'
def _get_comp_for(self): def _get_comp_for(self):
@@ -413,7 +424,7 @@ class SequenceLiteralContext(Sequence):
return "<%s of %s>" % (self.__class__.__name__, self.atom) return "<%s of %s>" % (self.__class__.__name__, self.atom)
class DictLiteralContext(SequenceLiteralContext): class DictLiteralContext(_DictMixin, SequenceLiteralContext):
array_type = u'dict' array_type = u'dict'
def __init__(self, evaluator, defining_context, atom): def __init__(self, evaluator, defining_context, atom):
@@ -479,7 +490,7 @@ class FakeSequence(_FakeArray):
return "<%s of %s>" % (type(self).__name__, self._lazy_context_list) return "<%s of %s>" % (type(self).__name__, self._lazy_context_list)
class FakeDict(_FakeArray): class FakeDict(_DictMixin, _FakeArray):
def __init__(self, evaluator, dct): def __init__(self, evaluator, dct):
super(FakeDict, self).__init__(evaluator, dct, u'dict') super(FakeDict, self).__init__(evaluator, dct, u'dict')
self._dct = dct self._dct = dct

View File

@@ -1,5 +1,6 @@
from collections import defaultdict from collections import defaultdict
from jedi import debug
from jedi.evaluate.utils import PushBackIterator from jedi.evaluate.utils import PushBackIterator
from jedi.evaluate import analysis from jedi.evaluate import analysis
from jedi.evaluate.lazy_context import LazyKnownContext, \ from jedi.evaluate.lazy_context import LazyKnownContext, \
@@ -49,9 +50,12 @@ class ExecutedParam(object):
# If we cannot infer annotations - or there aren't any - pretend # If we cannot infer annotations - or there aren't any - pretend
# that the signature matches. # that the signature matches.
return True return True
return any(c1.is_sub_class_of(c2) matches = any(c1.is_sub_class_of(c2)
for c1 in argument_contexts for c1 in argument_contexts
for c2 in annotations) for c2 in annotations)
debug.dbg("signature compare %s: %s <=> %s",
matches, argument_contexts, annotations, color='BLUE')
return matches
@property @property
def var_args(self): def var_args(self):