Fix a lot of list comprehensions.

This commit is contained in:
Dave Halter
2016-12-02 11:17:55 +01:00
parent dac372405e
commit 16a48a7a45
6 changed files with 70 additions and 33 deletions

View File

@@ -177,6 +177,9 @@ class Evaluator(object):
return types
def eval_element(self, context, element):
if isinstance(context, iterable.CompForContext):
return self._eval_element_not_cached(context, element)
if_stmt = element.get_parent_until((tree.IfStmt, tree.ForStmt, tree.IsScope))
predefined_if_name_dict = context.predefined_names.get(if_stmt)
if predefined_if_name_dict is None and isinstance(if_stmt, tree.IfStmt):

View File

@@ -72,6 +72,9 @@ class TreeContext(Context):
super(TreeContext, self).__init__(evaluator, parent_context)
self.predefined_names = {}
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.get_node())
class FlowContext(TreeContext):
def get_parent_flow_context(self):

View File

@@ -107,7 +107,7 @@ class NameFinder(object):
else:
self._string_name = name_or_str
self._position = position
self._found_predefined_if_name = None
self._found_predefined_types = None
@debug.increase_indent
def find(self, filters, attribute_lookup):
@@ -118,8 +118,8 @@ class NameFinder(object):
# TODO rename scopes to names_dicts
names = self.filter_name(filters)
if self._found_predefined_if_name is not None:
return self._found_predefined_if_name
if self._found_predefined_types is not None:
return self._found_predefined_types
types = self._names_to_types(names, attribute_lookup)
@@ -201,7 +201,8 @@ class NameFinder(object):
check = None
while True:
scope = scope.parent
if scope.type in ("if_stmt", "for_stmt", "comp_for"):
if scope.type in ("if_stmt", "for_stmt"):
# TODO try removing for_stmt.
try:
name_dict = self.context.predefined_names[scope]
types = set(name_dict[self._string_name])
@@ -212,14 +213,14 @@ class NameFinder(object):
# It doesn't make any sense to check if
# statements in the if statement itself, just
# deliver types.
self._found_predefined_if_name = types
self._found_predefined_types = types
else:
check = flow_analysis.reachability_check(
self._context, self._context, origin_scope)
if check is flow_analysis.UNREACHABLE:
self._found_predefined_if_name = set()
self._found_predefined_types = set()
else:
self._found_predefined_if_name = types
self._found_predefined_types = types
break
if isinstance(scope, tree.IsScope) or scope is None:
break
@@ -260,7 +261,7 @@ class NameFinder(object):
except KeyError:
continue
else:
self._found_predefined_if_name = types
self._found_predefined_types = types
return []
for filter in filters:
names = filter.get(self._name)
@@ -354,7 +355,6 @@ class NameFinder(object):
return result
@memoize_default(set(), evaluator_is_first_arg=True)
def _name_to_types(evaluator, context, name):
types = []
node = name.get_definition()
@@ -367,9 +367,12 @@ def _name_to_types(evaluator, context, name):
if types:
return types
if node.type in ('for_stmt', 'comp_for'):
container_types = context.eval_node(node.children[3])
for_types = iterable.py__iter__types(evaluator, container_types, node.children[3])
types = check_tuple_assignments(evaluator, for_types, name)
try:
types = context.predefined_names[node][name.value]
except KeyError:
container_types = context.eval_node(node.children[3])
for_types = iterable.py__iter__types(evaluator, container_types, node.children[3])
types = check_tuple_assignments(evaluator, for_types, name)
elif isinstance(node, tree.Param):
return set() # TODO remove
types = _eval_param(evaluator, context, node)

View File

@@ -31,7 +31,8 @@ from jedi.evaluate.cache import memoize_default
from jedi.evaluate import analysis
from jedi.evaluate import pep0484
from jedi import common
from jedi.evaluate.filters import DictFilter, AbstractNameDefinition
from jedi.evaluate.filters import DictFilter, AbstractNameDefinition, \
ParserTreeFilter
from jedi.evaluate import context
from jedi.evaluate import precedence
@@ -164,6 +165,22 @@ class Generator(GeneratorMixin, context.Context):
return "<%s of %s>" % (type(self).__name__, self._func_execution_context)
class CompForContext(context.TreeContext):
@classmethod
def from_comp_for(cls, parent_context, comp_for):
return cls(parent_context.evaluator, parent_context, comp_for)
def __init__(self, evaluator, parent_context, comp_for):
super(CompForContext, self).__init__(evaluator, parent_context)
self.node = comp_for
def get_node(self):
return self.node
def get_filters(self, search_global, until_position=None, origin_scope=None):
yield ParserTreeFilter(self.evaluator, self, self.node)
class Comprehension(AbstractSequence):
@staticmethod
def from_atom(evaluator, context, atom):
@@ -192,13 +209,14 @@ class Comprehension(AbstractSequence):
# The atom contains a testlist_comp
return self._get_comprehension().children[1]
@memoize_default()
def _eval_node(self, index=0):
"""
The first part `x + 1` of the list comprehension:
[x + 1 for x in foo]
"""
return self._get_comprehension().children[index]
#TODO delete
comp_for = self._get_comp_for()
# For nested comprehensions we need to search the last one.
node = self._get_comprehension().children[index]
@@ -206,25 +224,37 @@ class Comprehension(AbstractSequence):
#TODO raise NotImplementedError('should not need to copy...')
return helpers.deep_ast_copy(node, parent=last_comp)
def _nested(self, comp_fors):
@memoize_default()
def _get_comp_for_context(self, parent_context, comp_for):
return CompForContext.from_comp_for(
parent_context,
comp_for,
)
def _nested(self, comp_fors, parent_context=None):
evaluator = self.evaluator
comp_for = comp_fors[0]
input_node = comp_for.children[3]
input_types = self._defining_context.eval_node(input_node)
parent_context = parent_context or self._defining_context
input_types = parent_context.eval_node(input_node)
iterated = py__iter__(evaluator, input_types, input_node)
exprlist = comp_for.children[1]
for i, lazy_context in enumerate(iterated):
types = lazy_context.infer()
dct = unpack_tuple_to_dict(evaluator, types, exprlist)
with helpers.predefine_names(self._defining_context, comp_for, dct):
context = self._get_comp_for_context(
parent_context,
comp_for,
)
with helpers.predefine_names(context, comp_for, dct):
try:
for result in self._nested(comp_fors[1:]):
for result in self._nested(comp_fors[1:], context):
yield result
except IndexError:
iterated = self._defining_context.eval_node(self._eval_node())
iterated = context.eval_node(self._eval_node())
if self.array_type == 'dict':
yield iterated, self._defining_context.eval_node(self._eval_node(2))
yield iterated, context.eval_node(self._eval_node(2))
else:
yield iterated

View File

@@ -184,9 +184,6 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext)):
return names
return []
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.classdef)
@property
def name(self):
return ContextName(self, self.classdef.name)
@@ -252,9 +249,6 @@ class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)):
name = 'FUNCTION_CLASS'
return compiled.get_special_object(self.evaluator, name)
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.base_func)
@property
def name(self):
return ContextName(self, self.funcdef.name)
@@ -411,9 +405,6 @@ class FunctionExecutionContext(Executed):
def get_params(self):
return param.get_params(self.evaluator, self.parent_context, self.funcdef, self.var_args)
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.funcdef)
class AnonymousFunctionExecution(FunctionExecutionContext):
def __init__(self, evaluator, parent_context, funcdef):
@@ -605,6 +596,3 @@ class ModuleContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def py__class__(self):
return compiled.get_special_object(self.evaluator, 'MODULE_CLASS')
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.module_node)

View File

@@ -69,9 +69,17 @@ listen(['' for x in [1]])
# nested list comprehensions
# -----------------
b = [a for arr in [[1]] for a in arr]
b = [a for arr in [[1, 1.0]] for a in arr]
#? int()
b[0]
#? float()
b[1]
b = [arr for arr in [[1, 1.0]] for a in arr]
#? int()
b[0][0]
#? float()
b[1][1]
b = [a for arr in [[1]] if '' for a in arr if '']
#? int()
@@ -181,6 +189,8 @@ def x():
foo = [x for x in [1, '']][:1]
#? int()
foo[0]
#? str()
foo[1]
# -----------------
# In class