Fix a lot of list comprehensions.

This commit is contained in:
Dave Halter
2016-12-02 11:17:55 +01:00
parent dac372405e
commit 16a48a7a45
6 changed files with 70 additions and 33 deletions

View File

@@ -177,6 +177,9 @@ class Evaluator(object):
return types return types
def eval_element(self, context, element): def eval_element(self, context, element):
if isinstance(context, iterable.CompForContext):
return self._eval_element_not_cached(context, element)
if_stmt = element.get_parent_until((tree.IfStmt, tree.ForStmt, tree.IsScope)) if_stmt = element.get_parent_until((tree.IfStmt, tree.ForStmt, tree.IsScope))
predefined_if_name_dict = context.predefined_names.get(if_stmt) predefined_if_name_dict = context.predefined_names.get(if_stmt)
if predefined_if_name_dict is None and isinstance(if_stmt, tree.IfStmt): if predefined_if_name_dict is None and isinstance(if_stmt, tree.IfStmt):

View File

@@ -72,6 +72,9 @@ class TreeContext(Context):
super(TreeContext, self).__init__(evaluator, parent_context) super(TreeContext, self).__init__(evaluator, parent_context)
self.predefined_names = {} self.predefined_names = {}
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.get_node())
class FlowContext(TreeContext): class FlowContext(TreeContext):
def get_parent_flow_context(self): def get_parent_flow_context(self):

View File

@@ -107,7 +107,7 @@ class NameFinder(object):
else: else:
self._string_name = name_or_str self._string_name = name_or_str
self._position = position self._position = position
self._found_predefined_if_name = None self._found_predefined_types = None
@debug.increase_indent @debug.increase_indent
def find(self, filters, attribute_lookup): def find(self, filters, attribute_lookup):
@@ -118,8 +118,8 @@ class NameFinder(object):
# TODO rename scopes to names_dicts # TODO rename scopes to names_dicts
names = self.filter_name(filters) names = self.filter_name(filters)
if self._found_predefined_if_name is not None: if self._found_predefined_types is not None:
return self._found_predefined_if_name return self._found_predefined_types
types = self._names_to_types(names, attribute_lookup) types = self._names_to_types(names, attribute_lookup)
@@ -201,7 +201,8 @@ class NameFinder(object):
check = None check = None
while True: while True:
scope = scope.parent scope = scope.parent
if scope.type in ("if_stmt", "for_stmt", "comp_for"): if scope.type in ("if_stmt", "for_stmt"):
# TODO try removing for_stmt.
try: try:
name_dict = self.context.predefined_names[scope] name_dict = self.context.predefined_names[scope]
types = set(name_dict[self._string_name]) types = set(name_dict[self._string_name])
@@ -212,14 +213,14 @@ class NameFinder(object):
# It doesn't make any sense to check if # It doesn't make any sense to check if
# statements in the if statement itself, just # statements in the if statement itself, just
# deliver types. # deliver types.
self._found_predefined_if_name = types self._found_predefined_types = types
else: else:
check = flow_analysis.reachability_check( check = flow_analysis.reachability_check(
self._context, self._context, origin_scope) self._context, self._context, origin_scope)
if check is flow_analysis.UNREACHABLE: if check is flow_analysis.UNREACHABLE:
self._found_predefined_if_name = set() self._found_predefined_types = set()
else: else:
self._found_predefined_if_name = types self._found_predefined_types = types
break break
if isinstance(scope, tree.IsScope) or scope is None: if isinstance(scope, tree.IsScope) or scope is None:
break break
@@ -260,7 +261,7 @@ class NameFinder(object):
except KeyError: except KeyError:
continue continue
else: else:
self._found_predefined_if_name = types self._found_predefined_types = types
return [] return []
for filter in filters: for filter in filters:
names = filter.get(self._name) names = filter.get(self._name)
@@ -354,7 +355,6 @@ class NameFinder(object):
return result return result
@memoize_default(set(), evaluator_is_first_arg=True)
def _name_to_types(evaluator, context, name): def _name_to_types(evaluator, context, name):
types = [] types = []
node = name.get_definition() node = name.get_definition()
@@ -367,9 +367,12 @@ def _name_to_types(evaluator, context, name):
if types: if types:
return types return types
if node.type in ('for_stmt', 'comp_for'): if node.type in ('for_stmt', 'comp_for'):
container_types = context.eval_node(node.children[3]) try:
for_types = iterable.py__iter__types(evaluator, container_types, node.children[3]) types = context.predefined_names[node][name.value]
types = check_tuple_assignments(evaluator, for_types, name) except KeyError:
container_types = context.eval_node(node.children[3])
for_types = iterable.py__iter__types(evaluator, container_types, node.children[3])
types = check_tuple_assignments(evaluator, for_types, name)
elif isinstance(node, tree.Param): elif isinstance(node, tree.Param):
return set() # TODO remove return set() # TODO remove
types = _eval_param(evaluator, context, node) types = _eval_param(evaluator, context, node)

View File

@@ -31,7 +31,8 @@ from jedi.evaluate.cache import memoize_default
from jedi.evaluate import analysis from jedi.evaluate import analysis
from jedi.evaluate import pep0484 from jedi.evaluate import pep0484
from jedi import common from jedi import common
from jedi.evaluate.filters import DictFilter, AbstractNameDefinition from jedi.evaluate.filters import DictFilter, AbstractNameDefinition, \
ParserTreeFilter
from jedi.evaluate import context from jedi.evaluate import context
from jedi.evaluate import precedence from jedi.evaluate import precedence
@@ -164,6 +165,22 @@ class Generator(GeneratorMixin, context.Context):
return "<%s of %s>" % (type(self).__name__, self._func_execution_context) return "<%s of %s>" % (type(self).__name__, self._func_execution_context)
class CompForContext(context.TreeContext):
@classmethod
def from_comp_for(cls, parent_context, comp_for):
return cls(parent_context.evaluator, parent_context, comp_for)
def __init__(self, evaluator, parent_context, comp_for):
super(CompForContext, self).__init__(evaluator, parent_context)
self.node = comp_for
def get_node(self):
return self.node
def get_filters(self, search_global, until_position=None, origin_scope=None):
yield ParserTreeFilter(self.evaluator, self, self.node)
class Comprehension(AbstractSequence): class Comprehension(AbstractSequence):
@staticmethod @staticmethod
def from_atom(evaluator, context, atom): def from_atom(evaluator, context, atom):
@@ -192,13 +209,14 @@ class Comprehension(AbstractSequence):
# The atom contains a testlist_comp # The atom contains a testlist_comp
return self._get_comprehension().children[1] return self._get_comprehension().children[1]
@memoize_default()
def _eval_node(self, index=0): def _eval_node(self, index=0):
""" """
The first part `x + 1` of the list comprehension: The first part `x + 1` of the list comprehension:
[x + 1 for x in foo] [x + 1 for x in foo]
""" """
return self._get_comprehension().children[index]
#TODO delete
comp_for = self._get_comp_for() comp_for = self._get_comp_for()
# For nested comprehensions we need to search the last one. # For nested comprehensions we need to search the last one.
node = self._get_comprehension().children[index] node = self._get_comprehension().children[index]
@@ -206,25 +224,37 @@ class Comprehension(AbstractSequence):
#TODO raise NotImplementedError('should not need to copy...') #TODO raise NotImplementedError('should not need to copy...')
return helpers.deep_ast_copy(node, parent=last_comp) return helpers.deep_ast_copy(node, parent=last_comp)
def _nested(self, comp_fors): @memoize_default()
def _get_comp_for_context(self, parent_context, comp_for):
return CompForContext.from_comp_for(
parent_context,
comp_for,
)
def _nested(self, comp_fors, parent_context=None):
evaluator = self.evaluator evaluator = self.evaluator
comp_for = comp_fors[0] comp_for = comp_fors[0]
input_node = comp_for.children[3] input_node = comp_for.children[3]
input_types = self._defining_context.eval_node(input_node) parent_context = parent_context or self._defining_context
input_types = parent_context.eval_node(input_node)
iterated = py__iter__(evaluator, input_types, input_node) iterated = py__iter__(evaluator, input_types, input_node)
exprlist = comp_for.children[1] exprlist = comp_for.children[1]
for i, lazy_context in enumerate(iterated): for i, lazy_context in enumerate(iterated):
types = lazy_context.infer() types = lazy_context.infer()
dct = unpack_tuple_to_dict(evaluator, types, exprlist) dct = unpack_tuple_to_dict(evaluator, types, exprlist)
with helpers.predefine_names(self._defining_context, comp_for, dct): context = self._get_comp_for_context(
parent_context,
comp_for,
)
with helpers.predefine_names(context, comp_for, dct):
try: try:
for result in self._nested(comp_fors[1:]): for result in self._nested(comp_fors[1:], context):
yield result yield result
except IndexError: except IndexError:
iterated = self._defining_context.eval_node(self._eval_node()) iterated = context.eval_node(self._eval_node())
if self.array_type == 'dict': if self.array_type == 'dict':
yield iterated, self._defining_context.eval_node(self._eval_node(2)) yield iterated, context.eval_node(self._eval_node(2))
else: else:
yield iterated yield iterated

View File

@@ -184,9 +184,6 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext)):
return names return names
return [] return []
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.classdef)
@property @property
def name(self): def name(self):
return ContextName(self, self.classdef.name) return ContextName(self, self.classdef.name)
@@ -252,9 +249,6 @@ class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)):
name = 'FUNCTION_CLASS' name = 'FUNCTION_CLASS'
return compiled.get_special_object(self.evaluator, name) return compiled.get_special_object(self.evaluator, name)
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.base_func)
@property @property
def name(self): def name(self):
return ContextName(self, self.funcdef.name) return ContextName(self, self.funcdef.name)
@@ -411,9 +405,6 @@ class FunctionExecutionContext(Executed):
def get_params(self): def get_params(self):
return param.get_params(self.evaluator, self.parent_context, self.funcdef, self.var_args) return param.get_params(self.evaluator, self.parent_context, self.funcdef, self.var_args)
def __repr__(self):
return "<%s of %s>" % (self.__class__.__name__, self.funcdef)
class AnonymousFunctionExecution(FunctionExecutionContext): class AnonymousFunctionExecution(FunctionExecutionContext):
def __init__(self, evaluator, parent_context, funcdef): def __init__(self, evaluator, parent_context, funcdef):
@@ -605,6 +596,3 @@ class ModuleContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def py__class__(self): def py__class__(self):
return compiled.get_special_object(self.evaluator, 'MODULE_CLASS') return compiled.get_special_object(self.evaluator, 'MODULE_CLASS')
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.module_node)

View File

@@ -69,9 +69,17 @@ listen(['' for x in [1]])
# nested list comprehensions # nested list comprehensions
# ----------------- # -----------------
b = [a for arr in [[1]] for a in arr] b = [a for arr in [[1, 1.0]] for a in arr]
#? int() #? int()
b[0] b[0]
#? float()
b[1]
b = [arr for arr in [[1, 1.0]] for a in arr]
#? int()
b[0][0]
#? float()
b[1][1]
b = [a for arr in [[1]] if '' for a in arr if ''] b = [a for arr in [[1]] if '' for a in arr if '']
#? int() #? int()
@@ -181,6 +189,8 @@ def x():
foo = [x for x in [1, '']][:1] foo = [x for x in [1, '']][:1]
#? int() #? int()
foo[0] foo[0]
#? str()
foo[1]
# ----------------- # -----------------
# In class # In class