Infer dict.get() in a fancy way

This commit is contained in:
Dave Halter
2018-09-19 01:50:35 +02:00
parent 57fa5f5bd9
commit 9807a7f038
6 changed files with 80 additions and 38 deletions

View File

@@ -339,7 +339,7 @@ def signature_matches(function_context, arguments):
return False # TODO allow this
annotation_contexts = function_context.evaluator.eval_element(
function_context.parent_context,
function_context.get_default_param_context(),
param_node.annotation
)
argument_contexts = argument.infer().py__class__()

View File

@@ -188,26 +188,25 @@ class Sequence(BuiltinOverwrite, IterableMixin):
@memoize_method
def get_object(self):
klass = compiled.builtin_from_name(self.evaluator, self.array_type)
return self._annotate_class(klass)
annotated_instance, = self._annotate_class(klass).execute_evaluated()
return annotated_instance
def _annotate_class(self, klass):
from jedi.evaluate.pep0484 import define_type_vars_for_execution
instance, = klass.execute_evaluated(self)
instance, = klass.execute(self)
for init_function in self._get_init_functions(instance):
for execution_context in self._get_init_executions(instance):
# Just take the first result, it should always be one, because we
# control the typeshed code.
execution_context = init_function.get_function_execution()
return define_type_vars_for_execution(
ContextSet(klass),
execution_context,
klass.find_annotation_variables()
klass.list_type_vars()
)
return instance
assert "Should never land here, probably an issue with typeshed changes"
def _get_init_functions(self, instance):
def _get_init_executions(self, instance):
from jedi.evaluate import arguments
from jedi.evaluate.context.instance import InstanceArguments
for init in instance.py__getattribute__('__init__'):
@@ -216,12 +215,10 @@ class Sequence(BuiltinOverwrite, IterableMixin):
except AttributeError:
continue
else:
arguments = InstanceArguments(
instance,
arguments.ValuesArguments([ContextSet(self)])
)
for x in method(arguments):
yield x
base_args = arguments.ValuesArguments([ContextSet(self)])
arguments = InstanceArguments(instance, base_args)
for func in method(arguments):
yield func.get_function_execution(base_args)
def py__bool__(self):
return None # We don't know the length, because of appends.
@@ -439,6 +436,15 @@ class DictLiteralContext(SequenceLiteralContext):
return ContextSet(FakeSequence(self.evaluator, u'list', lazy_contexts))
def _dict_keys(self):
return ContextSet.from_sets(
self._defining_context.eval_node(k)
for k, v in self.get_tree_entries()
)
def get_mapping_item_contexts(self):
return self._dict_keys(), self._dict_values()
class _FakeArray(SequenceLiteralContext):
def __init__(self, evaluator, container, type):

View File

@@ -168,7 +168,7 @@ class ClassContext(use_metaclass(CachedMetaClass, TreeContext)):
api_type = u'class'
@evaluator_method_cache()
def find_annotation_variables(self):
def list_type_vars(self):
found = []
arglist = self.tree_node.get_super_arglist()
if arglist is None:

View File

@@ -488,7 +488,7 @@ class TypeVarFilter(object):
class _AbstractAnnotatedClass(ClassContext):
def get_type_var_filter(self):
return TypeVarFilter(self.get_given_types(), self.find_annotation_variables())
return TypeVarFilter(self.get_given_types(), self.list_type_vars())
def get_filters(self, search_global=False, *args, **kwargs):
for f in super(_AbstractAnnotatedClass, self).get_filters(search_global, *args, **kwargs):
@@ -530,7 +530,7 @@ class _AbstractAnnotatedClass(ClassContext):
return '<%s: %s%s>' % (
self.__class__.__name__,
self.name.string_name,
self.get_given_types(),
list(self.get_given_types()),
)
@to_list

View File

@@ -30,8 +30,9 @@ from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate import compiled
from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet
from jedi.evaluate.lazy_context import LazyTreeContext
from jedi.evaluate.context import ModuleContext
from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, AnnotatedSubClass
from jedi.evaluate.context import ModuleContext, ClassContext
from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, \
AnnotatedSubClass
from jedi.evaluate.helpers import is_string, execute_evaluated
from jedi import debug
from jedi import parser_utils
@@ -218,7 +219,7 @@ def infer_return_types(function_execution_context):
context = function_execution_context.function_context.get_default_param_context()
unknown_type_vars = list(find_unknown_type_vars(context, annotation))
if not unknown_type_vars:
return context.eval_node(annotation)
return context.eval_node(annotation).execute_annotation()
return define_type_vars_for_execution(
context.eval_node(annotation),
@@ -232,11 +233,11 @@ def define_type_vars_for_execution(to_define_contexts, execution_context,
all_annotations = py__annotations__(execution_context.tree_node)
return _define_type_vars(
to_define_contexts,
_infer_type_vars(execution_context, all_annotations),
_infer_type_vars_for_execution(execution_context, all_annotations),
)
def _infer_type_vars(execution_context, annotation_dict):
def _infer_type_vars_for_execution(execution_context, annotation_dict):
"""
Some functions use type vars that are not defined by the class, but rather
only defined in the function. See for example `iter`. In those cases we
@@ -263,49 +264,58 @@ def _infer_type_vars(execution_context, annotation_dict):
for ann in annotation_context_set:
_merge_type_var_dicts(
annotation_variable_results,
_unpack_type_vars(ann, actual_context_set),
_infer_type_vars(ann, actual_context_set),
)
return annotation_variable_results
def _define_type_vars(annotation_contexts, type_var_dict):
def remap_type_vars(type_var_contexts):
return ContextSet.from_sets(
type_var_dict.get(type_var, ContextSet(type_var))
for type_var in type_var_contexts
)
def remap_type_vars(cls):
for type_var in cls.list_type_vars():
yield type_var_dict.get(type_var.py__name__(), NO_CONTEXTS)
if not type_var_dict:
return annotation_contexts
context_set = ContextSet()
for annotation_context in annotation_contexts:
if isinstance(annotation_context, AnnotatedClass):
if isinstance(annotation_context, ClassContext):
context_set |= ContextSet.from_iterable([
AnnotatedSubClass(
annotation_context.evaluator,
annotation_context.parent_context,
annotation_context.tree_node,
tuple(remap_type_vars(tcs)
for tcs in annotation_context.get_given_types())
given_types=tuple(remap_type_vars(annotation_context))
)
])
return context_set
def _merge_type_var_dicts(base_dict, new_dict):
for type_var, contexts in new_dict.items():
for type_var_name, contexts in new_dict.items():
try:
base_dict[type_var] = contexts
base_dict[type_var_name] |= contexts
except KeyError:
base_dict[type_var] |= contexts
base_dict[type_var_name] = contexts
def _unpack_type_vars(annotation_context, context_set):
def _infer_type_vars(annotation_context, context_set):
"""
This function tries to find information about undefined type vars and
returns a dict from type var name to context set.
This is for example important to understand what `iter([1])` returns.
According to typeshed, `iter` returns an `Iterator[_T]`:
def iter(iterable: Iterable[_T]) -> Iterator[_T]: ...
This functions would generate `int` for `_T` in this case, because it
unpacks the `Iterable`.
"""
type_var_dict = {}
if isinstance(annotation_context, TypeVar):
return {annotation_context: context_set.py__class__()}
return {annotation_context.py__name__(): context_set.py__class__()}
elif isinstance(annotation_context, AnnotatedClass):
name = annotation_context.py__name__()
if name == 'Iterable':
@@ -314,11 +324,37 @@ def _unpack_type_vars(annotation_context, context_set):
for nested_annotation_context in given[0]:
_merge_type_var_dicts(
type_var_dict,
_unpack_type_vars(
_infer_type_vars(
nested_annotation_context,
context_set.merge_types_of_iterate()
)
)
elif name == 'Mapping':
given = annotation_context.get_given_types()
if len(given) == 2:
for context in context_set:
try:
method = context.get_mapping_item_contexts
except AttributeError:
continue
key_contexts, value_contexts = method()
for nested_annotation_context in given[0]:
_merge_type_var_dicts(
type_var_dict,
_infer_type_vars(
nested_annotation_context,
key_contexts,
)
)
for nested_annotation_context in given[1]:
_merge_type_var_dicts(
type_var_dict,
_infer_type_vars(
nested_annotation_context,
value_contexts,
)
)
return type_var_dict

View File

@@ -210,7 +210,7 @@ dic2 = {'asdf': 3, 'b': 'str'}
#? int()
dic2['asdf']
# TODO for now get doesn't work properly when used with a literal.
#? None
#? None int() str()
dic2.get('asdf')
# string literal