1
0
forked from VimPlug/jedi

Fix an issue with type vars

This commit is contained in:
Dave Halter
2018-09-22 21:00:42 +02:00
parent 389d4e3d9c
commit 994e7d1910
7 changed files with 63 additions and 61 deletions

View File

@@ -159,7 +159,7 @@ class AnonymousArguments(AbstractArguments):
execution_context.evaluator, execution_context.evaluator,
execution_context, execution_context,
execution_context.tree_node execution_context.tree_node
) ), []
def __repr__(self): def __repr__(self):
return '%s()' % self.__class__.__name__ return '%s()' % self.__class__.__name__

View File

@@ -27,6 +27,9 @@ class InstanceExecutedParam(object):
def infer(self): def infer(self):
return ContextSet(self._instance) return ContextSet(self._instance)
def matches_signature(self):
return True
class AnonymousInstanceArguments(AnonymousArguments): class AnonymousInstanceArguments(AnonymousArguments):
def __init__(self, instance): def __init__(self, instance):
@@ -42,14 +45,14 @@ class AnonymousInstanceArguments(AnonymousArguments):
if len(tree_params) == 1: if len(tree_params) == 1:
# If the only param is self, we don't need to try to find # If the only param is self, we don't need to try to find
# executions of this function, we have all the params already. # executions of this function, we have all the params already.
return [self_param] return [self_param], []
executed_params = list(search_params( executed_params = list(search_params(
execution_context.evaluator, execution_context.evaluator,
execution_context, execution_context,
execution_context.tree_node execution_context.tree_node
)) ))
executed_params[0] = self_param executed_params[0] = self_param
return [], executed_params return executed_params, []
class AbstractInstanceContext(Context): class AbstractInstanceContext(Context):
@@ -263,7 +266,7 @@ class TreeInstance(AbstractInstanceContext):
@evaluator_method_cache() @evaluator_method_cache()
def get_annotated_class_object(self): def get_annotated_class_object(self):
from jedi.evaluate.pep0484 import define_type_vars_for_execution from jedi.evaluate import pep0484
for func in self._get_annotation_init_functions(): for func in self._get_annotation_init_functions():
# Just take the first result, it should always be one, because we # Just take the first result, it should always be one, because we
@@ -274,13 +277,12 @@ class TreeInstance(AbstractInstanceContext):
# First check if the signature even matches, if not we don't # First check if the signature even matches, if not we don't
# need to infer anything. # need to infer anything.
continue continue
context_set = define_type_vars_for_execution(
ContextSet(self.class_context), all_annotations = pep0484.py__annotations__(execution.tree_node)
execution, return pep0484.define_type_vars(
self.class_context.list_type_vars() self.class_context,
pep0484.infer_type_vars_for_execution(execution, all_annotations),
) )
if context_set:
return next(iter(context_set))
return self.class_context return self.class_context
def _get_annotation_init_functions(self): def _get_annotation_init_functions(self):

View File

@@ -261,20 +261,16 @@ class ClassContext(use_metaclass(CachedMetaClass, TreeContext)):
return self.name.string_name return self.name.string_name
def py__getitem__(self, index_context_set, contextualized_node): def py__getitem__(self, index_context_set, contextualized_node):
from jedi.evaluate.context.typing import TypingClassMixin, AnnotatedClass from jedi.evaluate.context.typing import AnnotatedClass
#from pprint import pprint if not index_context_set:
for cls in py__mro__(self): return ContextSet(self)
if isinstance(cls, TypingClassMixin): return ContextSet.from_iterable(
# TODO get the right classes. AnnotatedClass(
return ContextSet.from_iterable( self.evaluator,
AnnotatedClass( self.parent_context,
self.evaluator, self.tree_node,
self.parent_context, index_context,
self.tree_node, context_of_index=contextualized_node.context,
index_context, )
context_of_index=contextualized_node.context, for index_context in index_context_set
) )
for index_context in index_context_set
)
return super(ClassContext, self).py__getitem__(index_context_set, contextualized_node)

View File

@@ -406,7 +406,7 @@ class TypeVar(_BaseTypingContext):
return ContextSet.from_sets( return ContextSet.from_sets(
l.infer() for l in self._constraints_lazy_contexts l.infer() for l in self._constraints_lazy_contexts
) )
debug.warning('Tried to infer the TypeVar %r without a given type', self._var_name) debug.warning('Tried to infer the TypeVar %s without a given type', self._var_name)
return NO_CONTEXTS return NO_CONTEXTS
@property @property

View File

@@ -69,7 +69,7 @@ def get_executed_params_and_issues(execution_context, var_args):
_add_argument_issue( _add_argument_issue(
default_param_context, default_param_context,
'type-error-too-many-arguments', 'type-error-too-many-arguments',
lazy_context, argument,
message=m message=m
) )
) )

View File

@@ -221,23 +221,34 @@ def infer_return_types(function_execution_context):
if not unknown_type_vars: if not unknown_type_vars:
return context.eval_node(annotation).execute_annotation() return context.eval_node(annotation).execute_annotation()
return define_type_vars_for_execution( annotations_contexts = context.eval_node(annotation)
context.eval_node(annotation), type_var_dict = infer_type_vars_for_execution(function_execution_context, all_annotations)
function_execution_context,
unknown_type_vars, def remap_type_vars(context, type_var_dict):
"""
The TypeVars in the resulting classes have sometimes different names
and we need to check for that, e.g. a signature can be:
def iter(iterable: Iterable[_T]) -> Iterator[_T]: ...
However, the iterator is defined as Iterator[_T_co], which means it has
a different type var name.
"""
if isinstance(context, ClassContext):
return {
to.py__name__(): type_var_dict.get(from_.py__name__(), NO_CONTEXTS)
for from_, to in zip(unknown_type_vars, context.list_type_vars())
}
return type_var_dict
return ContextSet.from_iterable(
define_type_vars(
annotation_context,
remap_type_vars(annotation_context, type_var_dict),
) for annotation_context in annotations_contexts
).execute_annotation() ).execute_annotation()
def define_type_vars_for_execution(to_define_contexts, execution_context, def infer_type_vars_for_execution(execution_context, annotation_dict):
unknown_type_vars):
all_annotations = py__annotations__(execution_context.tree_node)
return _define_type_vars(
to_define_contexts,
_infer_type_vars_for_execution(execution_context, all_annotations),
)
def _infer_type_vars_for_execution(execution_context, annotation_dict):
""" """
Some functions use type vars that are not defined by the class, but rather Some functions use type vars that are not defined by the class, but rather
only defined in the function. See for example `iter`. In those cases we only defined in the function. See for example `iter`. In those cases we
@@ -277,26 +288,19 @@ def _infer_type_vars_for_execution(execution_context, annotation_dict):
return annotation_variable_results return annotation_variable_results
def _define_type_vars(annotation_contexts, type_var_dict): def define_type_vars(annotation_context, type_var_dict):
def remap_type_vars(cls): def remap_type_vars(cls):
for type_var in cls.list_type_vars(): for type_var in cls.list_type_vars():
yield type_var_dict.get(type_var.py__name__(), NO_CONTEXTS) yield type_var_dict.get(type_var.py__name__(), NO_CONTEXTS)
if not type_var_dict: if type_var_dict and isinstance(annotation_context, ClassContext):
return annotation_contexts return AnnotatedSubClass(
annotation_context.evaluator,
context_set = ContextSet() annotation_context.parent_context,
for annotation_context in annotation_contexts: annotation_context.tree_node,
if isinstance(annotation_context, ClassContext): given_types=tuple(remap_type_vars(annotation_context))
context_set |= ContextSet.from_iterable([ )
AnnotatedSubClass( return annotation_context
annotation_context.evaluator,
annotation_context.parent_context,
annotation_context.tree_node,
given_types=tuple(remap_type_vars(annotation_context))
)
])
return context_set
def _merge_type_var_dicts(base_dict, new_dict): def _merge_type_var_dicts(base_dict, new_dict):

View File

@@ -273,9 +273,9 @@ d.values()[0]
x, = d.values() x, = d.values()
#? int() str() #? int() str()
x x
#? int() #? int() str()
d['a'] d['a']
#? int() None #? int() str() None
d.get('a') d.get('a')
# ----------------- # -----------------