First implementation of precise if-statement filtering.

This commit is contained in:
Dave Halter
2015-06-22 22:16:38 +02:00
parent 6da4f1fffb
commit 64fcbbba79
5 changed files with 114 additions and 3 deletions

View File

@@ -88,6 +88,7 @@ class Evaluator(object):
self.recursion_detector = recursion.RecursionDetector()
self.execution_recursion_detector = recursion.ExecutionRecursionDetector()
self.analysis = []
self.predefined_if_name_dict_dict = {}
def wrap(self, element):
if isinstance(element, tree.Class):
@@ -157,13 +158,78 @@ class Evaluator(object):
debug.dbg('eval_statement result %s', types)
return types
# NOTE(review): this decorator presumably memoizes every call, which would
# make the cached/uncached dispatch below redundant -- confirm whether the
# memoization is meant to live only on `_eval_element_cached`.
@memoize_default(evaluator_is_first_arg=True)
def eval_element(self, element):
    """
    Evaluate ``element``, taking an enclosing ``if`` statement into account.

    If ``element`` lies inside an ``if`` statement and shares a name with
    that statement's ``children[1]`` node (presumably the condition), every
    definition of such a name is evaluated separately: one "name dict"
    (name string -> list of definitions) is built per combination, it is
    temporarily registered in ``self.predefined_if_name_dict_dict`` for the
    ``if`` statement, and the element is evaluated without caching. The
    per-combination results are concatenated into one list.

    In all other cases the element is evaluated through the cached path,
    unless a non-empty name dict is already registered for the enclosing
    ``if`` statement, in which case the uncached path is used.

    :param element: The tree node (or ``iterable.AlreadyEvaluated`` /
        ``iterable.MergedNodes`` wrapper) to evaluate.
    :return: The evaluated types for ``element``.
    """
    if isinstance(element, iterable.AlreadyEvaluated):
        return list(element)
    elif isinstance(element, iterable.MergedNodes):
        return iterable.unite(self.eval_element(e) for e in element)

    parent = element.get_parent_until((tree.IfStmt, tree.IsScope))
    predefined_if_name_dict = self.predefined_if_name_dict_dict.get(parent)
    if predefined_if_name_dict or not isinstance(parent, tree.IfStmt):
        # Either a branch-specific name dict is already active for this
        # `if` statement (so a cached result must not be used), or there
        # is no enclosing `if` statement at all.
        if predefined_if_name_dict:
            return self._eval_element_not_cached(element)
        return self._eval_element_cached(element)

    if_stmt = parent.children[1]
    name_dicts = [{}]
    # If we already did a check, we don't want to do it again -> If
    # predefined_if_name_dict_dict is filled, we stop.
    # We don't want to check the if stmt itself, it's just about
    # the content.
    if element.start_pos > if_stmt.end_pos:
        # Now we need to check if the names in the if_stmt match the
        # names in the suite.
        if_names = helpers.get_names_of_node(if_stmt)
        element_names = helpers.get_names_of_node(element)
        str_element_names = set(str(e) for e in element_names)
        if any(str(i) in str_element_names for i in if_names):
            for if_name in if_names:
                definitions = self.goto_definition(if_name)
                # Every name that has multiple different definitions
                # causes the complexity to rise. The complexity should
                # never fall below 1.
                if len(definitions) > 1:
                    if len(name_dicts) * len(definitions) > 16:
                        debug.dbg('Too many options for if branch evaluation %s.', if_stmt)
                        # There's only a certain amount of branches
                        # Jedi can evaluate, otherwise it will take too
                        # long.
                        name_dicts = [{}]
                        break

                    # Build the cross product: each existing name dict is
                    # duplicated once per definition of this name.
                    expanded = []
                    for definition in definitions:
                        for name_dict in name_dicts:
                            branch = name_dict.copy()
                            branch[str(if_name)] = [definition]
                            expanded.append(branch)
                    name_dicts = expanded
                else:
                    for name_dict in name_dicts:
                        name_dict[str(if_name)] = definitions

    if len(name_dicts) > 1:
        result = []
        for name_dict in name_dicts:
            # Register the name dict only for the duration of this
            # evaluation; the `finally` guarantees it is removed again
            # even if the evaluation raises.
            self.predefined_if_name_dict_dict[parent] = name_dict
            try:
                result += self._eval_element_not_cached(element)
            finally:
                del self.predefined_if_name_dict_dict[parent]
        return result
    return self._eval_element_cached(element)
@memoize_default(evaluator_is_first_arg=True)
def _eval_element_cached(self, element):
    """
    Evaluate ``element`` by delegating to ``_eval_element_not_cached``.

    The only difference from the delegate is the ``memoize_default``
    decorator applied to this wrapper.
    """
    evaluated = self._eval_element_not_cached(element)
    return evaluated
def _eval_element_not_cached(self, element):
debug.dbg('eval_element %s@%s', element, element.start_pos)
if isinstance(element, (tree.Name, tree.Literal)) or tree.is_node(element, 'atom'):
return self._eval_atom(element)

View File

@@ -79,11 +79,17 @@ class NameFinder(object):
self.scope = evaluator.wrap(scope)
self.name_str = name_str
self.position = position
self._found_predefined_if_name = None
@debug.increase_indent
def find(self, scopes, search_global=False):
# TODO rename scopes to names_dicts
names = self.filter_name(scopes)
if self._found_predefined_if_name is not None:
print('HAVE FOUND', self._found_predefined_if_name)
return self._found_predefined_if_name
types = self._names_to_types(names, search_global)
if not names and not types \
@@ -150,8 +156,29 @@ class NameFinder(object):
if isinstance(self.name_str, tree.Name):
origin_scope = self.name_str.get_parent_until(tree.Scope, reverse=True)
scope = self.name_str
check = None
while True:
scope = scope.parent
if isinstance(scope, tree.IsScope) or scope is None:
break
elif isinstance(scope, tree.IfStmt):
try:
name_dict = self._evaluator.predefined_if_name_dict_dict[scope]
types = name_dict[str(self.name_str)]
except KeyError:
continue
else:
check = flow_analysis.break_check(self._evaluator, self.scope,
origin_scope)
if check is flow_analysis.UNREACHABLE:
self._found_predefined_if_name = []
else:
self._found_predefined_if_name = types
break
else:
origin_scope = None
if isinstance(stmt.parent, compiled.CompiledObject):
# TODO seriously? this is stupid.
continue
@@ -159,8 +186,10 @@ class NameFinder(object):
stmt, origin_scope)
if check is not flow_analysis.UNREACHABLE:
last_names.append(name)
if check is flow_analysis.REACHABLE:
break
last_names.append(name)
if isinstance(name_scope, er.FunctionExecution):
# Replace params

View File

@@ -41,10 +41,12 @@ def break_check(evaluator, base_scope, stmt, origin_scope=None):
# e.g. `if 0:` would make every name lookup within that flow
# inaccessible. This is not a "problem" in Python, because the code is
# never called. In Jedi though, we still want to infer types.
"""
while origin_scope is not None:
if element_scope == origin_scope:
return REACHABLE
origin_scope = origin_scope.parent
"""
return _break_check(evaluator, stmt, base_scope, element_scope)
@@ -62,7 +64,8 @@ def _break_check(evaluator, stmt, base_scope, element_scope):
reachable = reachable.invert()
else:
node = element_scope.node_in_which_check_node(stmt)
reachable = _check_if(evaluator, node)
if node is not None:
reachable = _check_if(evaluator, node)
elif isinstance(element_scope, (tree.TryStmt, tree.WhileStmt)):
return UNSURE

View File

@@ -2,6 +2,7 @@ import copy
from itertools import chain
from jedi.parser import tree
from jedi import common
def deep_ast_copy(obj, parent=None, new_elements=None):
@@ -105,6 +106,18 @@ def call_of_name(name, cut_own_trailer=False):
return par
def get_names_of_node(node):
    """
    Collect the nodes of type ``'name'`` found in ``node`` and below it.

    A node without a ``children`` attribute is treated as a leaf: it is
    returned in a one-element list when its ``type`` is ``'name'``, and an
    empty list is returned otherwise. For a node with children, the
    per-child results are combined with ``common.unite``.
    """
    try:
        sub_nodes = node.children
    except AttributeError:
        # Leaf: it either is a name itself or contributes nothing.
        return [node] if node.type == 'name' else []
    return common.unite(get_names_of_node(child) for child in sub_nodes)
def get_module_names(module, all_scopes):
"""
Returns a dictionary with name parts as keys and their call paths as

View File

@@ -362,7 +362,7 @@ class MergedArray(_FakeArray):
raise IndexError
def values(self):
    """
    Return the values of all merged arrays, combined via ``unite``.

    ``unite`` is handed a single iterable that yields each array's
    ``values()`` result, rather than one positional argument per array.
    """
    return unite(a.values() for a in self._arrays)
def __iter__(self):
for array in self._arrays: