Infer dict.get() in a fancy way

This commit is contained in:
Dave Halter
2018-09-19 01:50:35 +02:00
parent 57fa5f5bd9
commit 9807a7f038
6 changed files with 80 additions and 38 deletions

View File

@@ -339,7 +339,7 @@ def signature_matches(function_context, arguments):
return False # TODO allow this return False # TODO allow this
annotation_contexts = function_context.evaluator.eval_element( annotation_contexts = function_context.evaluator.eval_element(
function_context.parent_context, function_context.get_default_param_context(),
param_node.annotation param_node.annotation
) )
argument_contexts = argument.infer().py__class__() argument_contexts = argument.infer().py__class__()

View File

@@ -188,26 +188,25 @@ class Sequence(BuiltinOverwrite, IterableMixin):
@memoize_method @memoize_method
def get_object(self): def get_object(self):
klass = compiled.builtin_from_name(self.evaluator, self.array_type) klass = compiled.builtin_from_name(self.evaluator, self.array_type)
return self._annotate_class(klass) annotated_instance, = self._annotate_class(klass).execute_evaluated()
return annotated_instance
def _annotate_class(self, klass): def _annotate_class(self, klass):
from jedi.evaluate.pep0484 import define_type_vars_for_execution from jedi.evaluate.pep0484 import define_type_vars_for_execution
instance, = klass.execute_evaluated(self) instance, = klass.execute(self)
for init_function in self._get_init_functions(instance): for execution_context in self._get_init_executions(instance):
# Just take the first result, it should always be one, because we # Just take the first result, it should always be one, because we
# control the typeshed code. # control the typeshed code.
execution_context = init_function.get_function_execution()
return define_type_vars_for_execution( return define_type_vars_for_execution(
ContextSet(klass), ContextSet(klass),
execution_context, execution_context,
klass.find_annotation_variables() klass.list_type_vars()
) )
return instance
assert "Should never land here, probably an issue with typeshed changes" assert "Should never land here, probably an issue with typeshed changes"
def _get_init_functions(self, instance): def _get_init_executions(self, instance):
from jedi.evaluate import arguments from jedi.evaluate import arguments
from jedi.evaluate.context.instance import InstanceArguments from jedi.evaluate.context.instance import InstanceArguments
for init in instance.py__getattribute__('__init__'): for init in instance.py__getattribute__('__init__'):
@@ -216,12 +215,10 @@ class Sequence(BuiltinOverwrite, IterableMixin):
except AttributeError: except AttributeError:
continue continue
else: else:
arguments = InstanceArguments( base_args = arguments.ValuesArguments([ContextSet(self)])
instance, arguments = InstanceArguments(instance, base_args)
arguments.ValuesArguments([ContextSet(self)]) for func in method(arguments):
) yield func.get_function_execution(base_args)
for x in method(arguments):
yield x
def py__bool__(self): def py__bool__(self):
return None # We don't know the length, because of appends. return None # We don't know the length, because of appends.
@@ -439,6 +436,15 @@ class DictLiteralContext(SequenceLiteralContext):
return ContextSet(FakeSequence(self.evaluator, u'list', lazy_contexts)) return ContextSet(FakeSequence(self.evaluator, u'list', lazy_contexts))
def _dict_keys(self):
return ContextSet.from_sets(
self._defining_context.eval_node(k)
for k, v in self.get_tree_entries()
)
def get_mapping_item_contexts(self):
return self._dict_keys(), self._dict_values()
class _FakeArray(SequenceLiteralContext): class _FakeArray(SequenceLiteralContext):
def __init__(self, evaluator, container, type): def __init__(self, evaluator, container, type):

View File

@@ -168,7 +168,7 @@ class ClassContext(use_metaclass(CachedMetaClass, TreeContext)):
api_type = u'class' api_type = u'class'
@evaluator_method_cache() @evaluator_method_cache()
def find_annotation_variables(self): def list_type_vars(self):
found = [] found = []
arglist = self.tree_node.get_super_arglist() arglist = self.tree_node.get_super_arglist()
if arglist is None: if arglist is None:

View File

@@ -488,7 +488,7 @@ class TypeVarFilter(object):
class _AbstractAnnotatedClass(ClassContext): class _AbstractAnnotatedClass(ClassContext):
def get_type_var_filter(self): def get_type_var_filter(self):
return TypeVarFilter(self.get_given_types(), self.find_annotation_variables()) return TypeVarFilter(self.get_given_types(), self.list_type_vars())
def get_filters(self, search_global=False, *args, **kwargs): def get_filters(self, search_global=False, *args, **kwargs):
for f in super(_AbstractAnnotatedClass, self).get_filters(search_global, *args, **kwargs): for f in super(_AbstractAnnotatedClass, self).get_filters(search_global, *args, **kwargs):
@@ -530,7 +530,7 @@ class _AbstractAnnotatedClass(ClassContext):
return '<%s: %s%s>' % ( return '<%s: %s%s>' % (
self.__class__.__name__, self.__class__.__name__,
self.name.string_name, self.name.string_name,
self.get_given_types(), list(self.get_given_types()),
) )
@to_list @to_list

View File

@@ -30,8 +30,9 @@ from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate import compiled from jedi.evaluate import compiled
from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet
from jedi.evaluate.lazy_context import LazyTreeContext from jedi.evaluate.lazy_context import LazyTreeContext
from jedi.evaluate.context import ModuleContext from jedi.evaluate.context import ModuleContext, ClassContext
from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, AnnotatedSubClass from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, \
AnnotatedSubClass
from jedi.evaluate.helpers import is_string, execute_evaluated from jedi.evaluate.helpers import is_string, execute_evaluated
from jedi import debug from jedi import debug
from jedi import parser_utils from jedi import parser_utils
@@ -218,7 +219,7 @@ def infer_return_types(function_execution_context):
context = function_execution_context.function_context.get_default_param_context() context = function_execution_context.function_context.get_default_param_context()
unknown_type_vars = list(find_unknown_type_vars(context, annotation)) unknown_type_vars = list(find_unknown_type_vars(context, annotation))
if not unknown_type_vars: if not unknown_type_vars:
return context.eval_node(annotation) return context.eval_node(annotation).execute_annotation()
return define_type_vars_for_execution( return define_type_vars_for_execution(
context.eval_node(annotation), context.eval_node(annotation),
@@ -232,11 +233,11 @@ def define_type_vars_for_execution(to_define_contexts, execution_context,
all_annotations = py__annotations__(execution_context.tree_node) all_annotations = py__annotations__(execution_context.tree_node)
return _define_type_vars( return _define_type_vars(
to_define_contexts, to_define_contexts,
_infer_type_vars(execution_context, all_annotations), _infer_type_vars_for_execution(execution_context, all_annotations),
) )
def _infer_type_vars(execution_context, annotation_dict): def _infer_type_vars_for_execution(execution_context, annotation_dict):
""" """
Some functions use type vars that are not defined by the class, but rather Some functions use type vars that are not defined by the class, but rather
only defined in the function. See for example `iter`. In those cases we only defined in the function. See for example `iter`. In those cases we
@@ -263,49 +264,58 @@ def _infer_type_vars(execution_context, annotation_dict):
for ann in annotation_context_set: for ann in annotation_context_set:
_merge_type_var_dicts( _merge_type_var_dicts(
annotation_variable_results, annotation_variable_results,
_unpack_type_vars(ann, actual_context_set), _infer_type_vars(ann, actual_context_set),
) )
return annotation_variable_results return annotation_variable_results
def _define_type_vars(annotation_contexts, type_var_dict): def _define_type_vars(annotation_contexts, type_var_dict):
def remap_type_vars(type_var_contexts): def remap_type_vars(cls):
return ContextSet.from_sets( for type_var in cls.list_type_vars():
type_var_dict.get(type_var, ContextSet(type_var)) yield type_var_dict.get(type_var.py__name__(), NO_CONTEXTS)
for type_var in type_var_contexts
)
if not type_var_dict: if not type_var_dict:
return annotation_contexts return annotation_contexts
context_set = ContextSet() context_set = ContextSet()
for annotation_context in annotation_contexts: for annotation_context in annotation_contexts:
if isinstance(annotation_context, AnnotatedClass): if isinstance(annotation_context, ClassContext):
context_set |= ContextSet.from_iterable([ context_set |= ContextSet.from_iterable([
AnnotatedSubClass( AnnotatedSubClass(
annotation_context.evaluator, annotation_context.evaluator,
annotation_context.parent_context, annotation_context.parent_context,
annotation_context.tree_node, annotation_context.tree_node,
tuple(remap_type_vars(tcs) given_types=tuple(remap_type_vars(annotation_context))
for tcs in annotation_context.get_given_types())
) )
]) ])
return context_set return context_set
def _merge_type_var_dicts(base_dict, new_dict): def _merge_type_var_dicts(base_dict, new_dict):
for type_var, contexts in new_dict.items(): for type_var_name, contexts in new_dict.items():
try: try:
base_dict[type_var] = contexts base_dict[type_var_name] |= contexts
except KeyError: except KeyError:
base_dict[type_var] |= contexts base_dict[type_var_name] = contexts
def _unpack_type_vars(annotation_context, context_set): def _infer_type_vars(annotation_context, context_set):
"""
This function tries to find information about undefined type vars and
returns a dict from type var name to context set.
This is for example important to understand what `iter([1])` returns.
According to typeshed, `iter` returns an `Iterator[_T]`:
def iter(iterable: Iterable[_T]) -> Iterator[_T]: ...
This functions would generate `int` for `_T` in this case, because it
unpacks the `Iterable`.
"""
type_var_dict = {} type_var_dict = {}
if isinstance(annotation_context, TypeVar): if isinstance(annotation_context, TypeVar):
return {annotation_context: context_set.py__class__()} return {annotation_context.py__name__(): context_set.py__class__()}
elif isinstance(annotation_context, AnnotatedClass): elif isinstance(annotation_context, AnnotatedClass):
name = annotation_context.py__name__() name = annotation_context.py__name__()
if name == 'Iterable': if name == 'Iterable':
@@ -314,11 +324,37 @@ def _unpack_type_vars(annotation_context, context_set):
for nested_annotation_context in given[0]: for nested_annotation_context in given[0]:
_merge_type_var_dicts( _merge_type_var_dicts(
type_var_dict, type_var_dict,
_unpack_type_vars( _infer_type_vars(
nested_annotation_context, nested_annotation_context,
context_set.merge_types_of_iterate() context_set.merge_types_of_iterate()
) )
) )
elif name == 'Mapping':
given = annotation_context.get_given_types()
if len(given) == 2:
for context in context_set:
try:
method = context.get_mapping_item_contexts
except AttributeError:
continue
key_contexts, value_contexts = method()
for nested_annotation_context in given[0]:
_merge_type_var_dicts(
type_var_dict,
_infer_type_vars(
nested_annotation_context,
key_contexts,
)
)
for nested_annotation_context in given[1]:
_merge_type_var_dicts(
type_var_dict,
_infer_type_vars(
nested_annotation_context,
value_contexts,
)
)
return type_var_dict return type_var_dict

View File

@@ -210,7 +210,7 @@ dic2 = {'asdf': 3, 'b': 'str'}
#? int() #? int()
dic2['asdf'] dic2['asdf']
# TODO for now get doesn't work properly when used with a literal. # TODO for now get doesn't work properly when used with a literal.
#? None #? None int() str()
dic2.get('asdf') dic2.get('asdf')
# string literal # string literal