Add a ContextSet.

This is not bug free yet, but it's going to be a good abstraction for a lot of small things.
This commit is contained in:
Dave Halter
2017-09-25 11:04:09 +02:00
parent a433ee7a7e
commit 5328d1e700
14 changed files with 257 additions and 190 deletions

2
jedi/common/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from jedi.common.context import ContextSet, NO_CONTEXTS, \
iterator_to_context_set

58
jedi/common/context.py Normal file
View File

@@ -0,0 +1,58 @@
class ContextSet(object):
def __init__(self, *args):
self._set = set(args)
@classmethod
def from_iterable(cls, iterable):
return cls.from_set(set(iterable))
@classmethod
def from_set(cls, set_):
self = cls()
self._set = set_
return self
@classmethod
def from_sets(cls, sets):
"""
Used to work with an iterable of set.
"""
aggregated = set()
for set_ in sets:
print(set_)
if isinstance(set_, ContextSet):
aggregated |= set_._set
else:
aggregated |= set_
return cls.from_set(aggregated)
def __or__(self, other):
return ContextSet.from_set(self._set | other._set)
def __iter__(self):
for element in self._set:
yield element
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, ', '.join(str(s) for s in self._set))
def filter(self, filter_func):
return ContextSet(filter(filter_func, self._set))
def __getattr__(self, name):
def mapper(*args, **kwargs):
return ContextSet.from_sets(
getattr(context, name)(*args, **kwargs)
for context in self._set
)
return mapper
def iterator_to_context_set(func):
def wrapper(*args, **kwargs):
return ContextSet.from_iterable(func(*args, **kwargs))
return wrapper
NO_CONTEXTS = ContextSet()

View File

@@ -83,6 +83,7 @@ from jedi.evaluate import pep0484
from jedi.evaluate.filters import TreeNameDefinition, ParamName from jedi.evaluate.filters import TreeNameDefinition, ParamName
from jedi.evaluate.instance import AnonymousInstance, BoundMethod from jedi.evaluate.instance import AnonymousInstance, BoundMethod
from jedi.evaluate.context import ContextualizedName, ContextualizedNode from jedi.evaluate.context import ContextualizedName, ContextualizedNode
from jedi.common import ContextSet, NO_CONTEXTS
from jedi import parser_utils from jedi import parser_utils
@@ -101,7 +102,7 @@ def _limit_context_infers(func):
evaluator.inferred_element_counts[n] += 1 evaluator.inferred_element_counts[n] += 1
if evaluator.inferred_element_counts[n] > 300: if evaluator.inferred_element_counts[n] > 300:
debug.warning('In context %s there were too many inferences.', n) debug.warning('In context %s there were too many inferences.', n)
return set() return NO_CONTEXTS
except KeyError: except KeyError:
evaluator.inferred_element_counts[n] = 1 evaluator.inferred_element_counts[n] = 1
return func(evaluator, context, *args, **kwargs) return func(evaluator, context, *args, **kwargs)
@@ -163,7 +164,7 @@ class Evaluator(object):
with recursion.execution_allowed(self, stmt) as allowed: with recursion.execution_allowed(self, stmt) as allowed:
if allowed or context.get_root_context() == self.BUILTINS: if allowed or context.get_root_context() == self.BUILTINS:
return self._eval_stmt(context, stmt, seek_name) return self._eval_stmt(context, stmt, seek_name)
return set() return NO_CONTEXTS
#@evaluator_function_cache(default=[]) #@evaluator_function_cache(default=[])
@debug.increase_indent @debug.increase_indent
@@ -178,11 +179,11 @@ class Evaluator(object):
""" """
debug.dbg('eval_statement %s (%s)', stmt, seek_name) debug.dbg('eval_statement %s (%s)', stmt, seek_name)
rhs = stmt.get_rhs() rhs = stmt.get_rhs()
types = self.eval_element(context, rhs) context_set = self.eval_element(context, rhs)
if seek_name: if seek_name:
c_node = ContextualizedName(context, seek_name) c_node = ContextualizedName(context, seek_name)
types = finder.check_tuple_assignments(self, c_node, types) context_set = finder.check_tuple_assignments(self, c_node, context_set)
first_operator = next(stmt.yield_operators(), None) first_operator = next(stmt.yield_operators(), None)
if first_operator not in ('=', None) and first_operator.type == 'operator': if first_operator not in ('=', None) and first_operator.type == 'operator':
@@ -194,7 +195,7 @@ class Evaluator(object):
name, position=stmt.start_pos, search_global=True) name, position=stmt.start_pos, search_global=True)
for_stmt = tree.search_ancestor(stmt, 'for_stmt') for_stmt = tree.search_ancestor(stmt, 'for_stmt')
if for_stmt is not None and for_stmt.type == 'for_stmt' and types \ if for_stmt is not None and for_stmt.type == 'for_stmt' and context_set \
and parser_utils.for_stmt_defines_one_name(for_stmt): and parser_utils.for_stmt_defines_one_name(for_stmt):
# Iterate through result and add the values, that's possible # Iterate through result and add the values, that's possible
# only in for loops without clutter, because they are # only in for loops without clutter, because they are
@@ -208,11 +209,11 @@ class Evaluator(object):
with helpers.predefine_names(context, for_stmt, dct): with helpers.predefine_names(context, for_stmt, dct):
t = self.eval_element(context, rhs) t = self.eval_element(context, rhs)
left = precedence.calculate(self, context, left, operator, t) left = precedence.calculate(self, context, left, operator, t)
types = left context_set = left
else: else:
types = precedence.calculate(self, context, left, operator, types) context_set = precedence.calculate(self, context, left, operator, context_set)
debug.dbg('eval_statement result %s', types) debug.dbg('eval_statement result %s', context_set)
return types return context_set
def eval_element(self, context, element): def eval_element(self, context, element):
if isinstance(context, iterable.CompForContext): if isinstance(context, iterable.CompForContext):
@@ -261,14 +262,14 @@ class Evaluator(object):
new_name_dicts = list(original_name_dicts) new_name_dicts = list(original_name_dicts)
for i, name_dict in enumerate(new_name_dicts): for i, name_dict in enumerate(new_name_dicts):
new_name_dicts[i] = name_dict.copy() new_name_dicts[i] = name_dict.copy()
new_name_dicts[i][if_name.value] = set([definition]) new_name_dicts[i][if_name.value] = ContextSet(definition)
name_dicts += new_name_dicts name_dicts += new_name_dicts
else: else:
for name_dict in name_dicts: for name_dict in name_dicts:
name_dict[if_name.value] = definitions name_dict[if_name.value] = definitions
if len(name_dicts) > 1: if len(name_dicts) > 1:
result = set() result = ContextSet()
for name_dict in name_dicts: for name_dict in name_dicts:
with helpers.predefine_names(context, if_stmt, name_dict): with helpers.predefine_names(context, if_stmt, name_dict):
result |= self._eval_element_not_cached(context, element) result |= self._eval_element_not_cached(context, element)
@@ -293,7 +294,7 @@ class Evaluator(object):
return self._eval_element_not_cached(context, element) return self._eval_element_not_cached(context, element)
return self._eval_element_cached(context, element) return self._eval_element_cached(context, element)
@evaluator_function_cache(default=set()) @evaluator_function_cache(default=NO_CONTEXTS)
def _eval_element_cached(self, context, element): def _eval_element_cached(self, context, element):
return self._eval_element_not_cached(context, element) return self._eval_element_not_cached(context, element)
@@ -301,63 +302,66 @@ class Evaluator(object):
@_limit_context_infers @_limit_context_infers
def _eval_element_not_cached(self, context, element): def _eval_element_not_cached(self, context, element):
debug.dbg('eval_element %s@%s', element, element.start_pos) debug.dbg('eval_element %s@%s', element, element.start_pos)
types = set()
typ = element.type typ = element.type
if typ in ('name', 'number', 'string', 'atom'): if typ in ('name', 'number', 'string', 'atom'):
types = self.eval_atom(context, element) return self.eval_atom(context, element)
elif typ == 'keyword': elif typ == 'keyword':
# For False/True/None # For False/True/None
if element.value in ('False', 'True', 'None'): if element.value in ('False', 'True', 'None'):
types.add(compiled.builtin_from_name(self, element.value)) return ContextSet(compiled.builtin_from_name(self, element.value))
# else: print e.g. could be evaluated like this in Python 2.7 # else: print e.g. could be evaluated like this in Python 2.7
return NO_CONTEXTS
elif typ == 'lambdef': elif typ == 'lambdef':
types = set([er.FunctionContext(self, context, element)]) return ContextSet(er.FunctionContext(self, context, element))
elif typ == 'expr_stmt': elif typ == 'expr_stmt':
types = self.eval_statement(context, element) return self.eval_statement(context, element)
elif typ in ('power', 'atom_expr'): elif typ in ('power', 'atom_expr'):
first_child = element.children[0] first_child = element.children[0]
if not (first_child.type == 'keyword' and first_child.value == 'await'): if not (first_child.type == 'keyword' and first_child.value == 'await'):
types = self.eval_atom(context, first_child) context_set = self.eval_atom(context, first_child)
for trailer in element.children[1:]: for trailer in element.children[1:]:
if trailer == '**': # has a power operation. if trailer == '**': # has a power operation.
right = self.eval_element(context, element.children[2]) right = self.eval_element(context, element.children[2])
types = set(precedence.calculate(self, context, types, trailer, right)) context_set = precedence.calculate(
self,
context,
context_set,
trailer,
right
)
break break
types = self.eval_trailer(context, types, trailer) context_set = self.eval_trailer(context, context_set, trailer)
return context_set
elif typ in ('testlist_star_expr', 'testlist',): elif typ in ('testlist_star_expr', 'testlist',):
# The implicit tuple in statements. # The implicit tuple in statements.
types = set([iterable.SequenceLiteralContext(self, context, element)]) return ContextSet(iterable.SequenceLiteralContext(self, context, element))
elif typ in ('not_test', 'factor'): elif typ in ('not_test', 'factor'):
types = self.eval_element(context, element.children[-1]) context_set = self.eval_element(context, element.children[-1])
for operator in element.children[:-1]: for operator in element.children[:-1]:
types = set(precedence.factor_calculate(self, types, operator)) context_set = precedence.factor_calculate(self, context_set, operator)
return context_set
elif typ == 'test': elif typ == 'test':
# `x if foo else y` case. # `x if foo else y` case.
types = (self.eval_element(context, element.children[0]) | return (self.eval_element(context, element.children[0]) |
self.eval_element(context, element.children[-1])) self.eval_element(context, element.children[-1]))
elif typ == 'operator': elif typ == 'operator':
# Must be an ellipsis, other operators are not evaluated. # Must be an ellipsis, other operators are not evaluated.
# In Python 2 ellipsis is coded as three single dot tokens, not # In Python 2 ellipsis is coded as three single dot tokens, not
# as one token 3 dot token. # as one token 3 dot token.
assert element.value in ('.', '...') assert element.value in ('.', '...')
types = set([compiled.create(self, Ellipsis)]) return ContextSet(compiled.create(self, Ellipsis))
elif typ == 'dotted_name': elif typ == 'dotted_name':
types = self.eval_atom(context, element.children[0]) context_set = self.eval_atom(context, element.children[0])
for next_name in element.children[2::2]: for next_name in element.children[2::2]:
# TODO add search_global=True? # TODO add search_global=True?
types = unite( context_set.py__getattribute__(next_name, name_context=context)
typ.py__getattribute__(next_name, name_context=context) return context_set
for typ in types
)
types = types
elif typ == 'eval_input': elif typ == 'eval_input':
types = self._eval_element_not_cached(context, element.children[0]) return self._eval_element_not_cached(context, element.children[0])
elif typ == 'annassign': elif typ == 'annassign':
types = pep0484._evaluate_for_annotation(context, element.children[1]) return pep0484._evaluate_for_annotation(context, element.children[1])
else: else:
types = precedence.calculate_children(self, context, element.children) return precedence.calculate_children(self, context, element.children)
debug.dbg('eval_element result %s', types)
return types
def eval_atom(self, context, atom): def eval_atom(self, context, atom):
""" """
@@ -377,18 +381,19 @@ class Evaluator(object):
position=stmt.start_pos, position=stmt.start_pos,
search_global=True search_global=True
) )
elif isinstance(atom, tree.Literal): elif isinstance(atom, tree.Literal):
string = parser_utils.safe_literal_eval(atom.value) string = parser_utils.safe_literal_eval(atom.value)
return set([compiled.create(self, string)]) return ContextSet(compiled.create(self, string))
else: else:
c = atom.children c = atom.children
if c[0].type == 'string': if c[0].type == 'string':
# Will be one string. # Will be one string.
types = self.eval_atom(context, c[0]) context_set = self.eval_atom(context, c[0])
for string in c[1:]: for string in c[1:]:
right = self.eval_atom(context, string) right = self.eval_atom(context, string)
types = precedence.calculate(self, context, types, '+', right) context_set = precedence.calculate(self, context, context_set, '+', right)
return types return context_set
# Parentheses without commas are not tuples. # Parentheses without commas are not tuples.
elif c[0] == '(' and not len(c) == 2 \ elif c[0] == '(' and not len(c) == 2 \
and not(c[1].type == 'testlist_comp' and and not(c[1].type == 'testlist_comp' and
@@ -408,7 +413,7 @@ class Evaluator(object):
pass pass
if comp_for.type == 'comp_for': if comp_for.type == 'comp_for':
return set([iterable.Comprehension.from_atom(self, context, atom)]) return ContextSet(iterable.Comprehension.from_atom(self, context, atom))
# It's a dict/list/tuple literal. # It's a dict/list/tuple literal.
array_node = c[1] array_node = c[1]
@@ -420,28 +425,28 @@ class Evaluator(object):
context = iterable.DictLiteralContext(self, context, atom) context = iterable.DictLiteralContext(self, context, atom)
else: else:
context = iterable.SequenceLiteralContext(self, context, atom) context = iterable.SequenceLiteralContext(self, context, atom)
return set([context]) return ContextSet(context)
def eval_trailer(self, context, types, trailer): def eval_trailer(self, context, types, trailer):
trailer_op, node = trailer.children[:2] trailer_op, node = trailer.children[:2]
if node == ')': # `arglist` is optional. if node == ')': # `arglist` is optional.
node = () node = ()
new_types = set()
if trailer_op == '[': if trailer_op == '[':
new_types |= iterable.py__getitem__(self, context, types, trailer) return ContextSet(iterable.py__getitem__(self, context, types, trailer))
else: else:
context_set = ContextSet()
for typ in types: for typ in types:
debug.dbg('eval_trailer: %s in scope %s', trailer, typ) debug.dbg('eval_trailer: %s in scope %s', trailer, typ)
if trailer_op == '.': if trailer_op == '.':
new_types |= typ.py__getattribute__( context_set |= typ.py__getattribute__(
name_context=context, name_context=context,
name_or_str=node name_or_str=node
) )
elif trailer_op == '(': elif trailer_op == '(':
arguments = param.TreeArguments(self, context, node, trailer) arguments = param.TreeArguments(self, context, node, trailer)
new_types |= self.execute(typ, arguments) context_set |= self.execute(typ, arguments)
return new_types return context_set
@debug.increase_indent @debug.increase_indent
def execute(self, obj, arguments): def execute(self, obj, arguments):
@@ -460,11 +465,11 @@ class Evaluator(object):
func = obj.py__call__ func = obj.py__call__
except AttributeError: except AttributeError:
debug.warning("no execution possible %s", obj) debug.warning("no execution possible %s", obj)
return set() return NO_CONTEXTS
else: else:
types = func(arguments) context_set = func(arguments)
debug.dbg('execute result: %s in %s', types, obj) debug.dbg('execute result: %s in %s', context_set, obj)
return types return context_set
def goto_definitions(self, context, name): def goto_definitions(self, context, name):
def_ = name.get_definition(import_name_always=True) def_ = name.get_definition(import_name_always=True)
@@ -509,25 +514,25 @@ class Evaluator(object):
return module_names return module_names
par = name.parent par = name.parent
typ = par.type node_type = par.type
if typ == 'argument' and par.children[1] == '=' and par.children[0] == name: if node_type == 'argument' and par.children[1] == '=' and par.children[0] == name:
# Named param goto. # Named param goto.
trailer = par.parent trailer = par.parent
if trailer.type == 'arglist': if trailer.type == 'arglist':
trailer = trailer.parent trailer = trailer.parent
if trailer.type != 'classdef': if trailer.type != 'classdef':
if trailer.type == 'decorator': if trailer.type == 'decorator':
types = self.eval_element(context, trailer.children[1]) context_set = self.eval_element(context, trailer.children[1])
else: else:
i = trailer.parent.children.index(trailer) i = trailer.parent.children.index(trailer)
to_evaluate = trailer.parent.children[:i] to_evaluate = trailer.parent.children[:i]
types = self.eval_element(context, to_evaluate[0]) context_set = self.eval_element(context, to_evaluate[0])
for trailer in to_evaluate[1:]: for trailer in to_evaluate[1:]:
types = self.eval_trailer(context, types, trailer) context_set = self.eval_trailer(context, context_set, trailer)
param_names = [] param_names = []
for typ in types: for context in context_set:
try: try:
get_param_names = typ.get_param_names get_param_names = context.get_param_names
except AttributeError: except AttributeError:
pass pass
else: else:
@@ -535,7 +540,7 @@ class Evaluator(object):
if param_name.string_name == name.value: if param_name.string_name == name.value:
param_names.append(param_name) param_names.append(param_name)
return param_names return param_names
elif typ == 'dotted_name': # Is a decorator. elif node_type == 'dotted_name': # Is a decorator.
index = par.children.index(name) index = par.children.index(name)
if index > 0: if index > 0:
new_dotted = helpers.deep_ast_copy(par) new_dotted = helpers.deep_ast_copy(par)
@@ -546,7 +551,7 @@ class Evaluator(object):
for value in values for value in values
) )
if typ == 'trailer' and par.children[0] == '.': if node_type == 'trailer' and par.children[0] == '.':
values = helpers.evaluate_call_of_leaf(context, name, cut_own_trailer=True) values = helpers.evaluate_call_of_leaf(context, name, cut_own_trailer=True)
return unite( return unite(
value.py__getattribute__(name, name_context=context, is_goto=True) value.py__getattribute__(name, name_context=context, is_goto=True)

View File

@@ -15,6 +15,7 @@ from jedi.evaluate.filters import AbstractFilter, AbstractNameDefinition, \
ContextNameMixin ContextNameMixin
from jedi.evaluate.context import Context, LazyKnownContext from jedi.evaluate.context import Context, LazyKnownContext
from jedi.evaluate.compiled.getattr_static import getattr_static from jedi.evaluate.compiled.getattr_static import getattr_static
from jedi.common import ContextSet
from . import fake from . import fake
@@ -83,9 +84,9 @@ class CompiledObject(Context):
def py__call__(self, params): def py__call__(self, params):
if inspect.isclass(self.obj): if inspect.isclass(self.obj):
from jedi.evaluate.instance import CompiledInstance from jedi.evaluate.instance import CompiledInstance
return set([CompiledInstance(self.evaluator, self.parent_context, self, params)]) return ContextSet(CompiledInstance(self.evaluator, self.parent_context, self, params))
else: else:
return set(self._execute_function(params)) return ContextSet.from_iterable(self._execute_function(params))
@CheckAttribute @CheckAttribute
def py__class__(self): def py__class__(self):
@@ -221,9 +222,9 @@ class CompiledObject(Context):
def py__getitem__(self, index): def py__getitem__(self, index):
if type(self.obj) not in (str, list, tuple, unicode, bytes, bytearray, dict): if type(self.obj) not in (str, list, tuple, unicode, bytes, bytearray, dict):
# Get rid of side effects, we won't call custom `__getitem__`s. # Get rid of side effects, we won't call custom `__getitem__`s.
return set() return ContextSet()
return set([create(self.evaluator, self.obj[index])]) return ContextSet(create(self.evaluator, self.obj[index]))
@CheckAttribute @CheckAttribute
def py__iter__(self): def py__iter__(self):
@@ -278,7 +279,7 @@ class CompiledObject(Context):
return [] # Builtins don't have imports return [] # Builtins don't have imports
def dict_values(self): def dict_values(self):
return set(create(self.evaluator, v) for v in self.obj.values()) return ContextSet(create(self.evaluator, v) for v in self.obj.values())
class CompiledName(AbstractNameDefinition): class CompiledName(AbstractNameDefinition):
@@ -301,7 +302,9 @@ class CompiledName(AbstractNameDefinition):
@underscore_memoization @underscore_memoization
def infer(self): def infer(self):
module = self.parent_context.get_root_context() module = self.parent_context.get_root_context()
return [_create_from_name(self._evaluator, module, self.parent_context, self.string_name)] return ContextSet(_create_from_name(
self._evaluator, module, self.parent_context, self.string_name
))
class SignatureParamName(AbstractNameDefinition): class SignatureParamName(AbstractNameDefinition):
@@ -318,13 +321,13 @@ class SignatureParamName(AbstractNameDefinition):
def infer(self): def infer(self):
p = self._signature_param p = self._signature_param
evaluator = self.parent_context.evaluator evaluator = self.parent_context.evaluator
types = set() contexts = ContextSet()
if p.default is not p.empty: if p.default is not p.empty:
types.add(create(evaluator, p.default)) contexts = ContextSet(create(evaluator, p.default))
if p.annotation is not p.empty: if p.annotation is not p.empty:
annotation = create(evaluator, p.annotation) annotation = create(evaluator, p.annotation)
types |= annotation.execute_evaluated() contexts |= annotation.execute_evaluated()
return types return contexts
class UnresolvableParamName(AbstractNameDefinition): class UnresolvableParamName(AbstractNameDefinition):
@@ -335,7 +338,7 @@ class UnresolvableParamName(AbstractNameDefinition):
self.string_name = name self.string_name = name
def infer(self): def infer(self):
return set() return ContextSet()
class CompiledContextName(ContextNameMixin, AbstractNameDefinition): class CompiledContextName(ContextNameMixin, AbstractNameDefinition):
@@ -356,7 +359,7 @@ class EmptyCompiledName(AbstractNameDefinition):
self.string_name = name self.string_name = name
def infer(self): def infer(self):
return [] return ContextSet()
class CompiledObjectFilter(AbstractFilter): class CompiledObjectFilter(AbstractFilter):

View File

@@ -1,7 +1,7 @@
from jedi._compatibility import Python3Method from jedi._compatibility import Python3Method
from jedi.evaluate.utils import unite
from parso.python.tree import ExprStmt, CompFor from parso.python.tree import ExprStmt, CompFor
from jedi.parser_utils import clean_scope_docstring, get_doc_with_call_signature from jedi.parser_utils import clean_scope_docstring, get_doc_with_call_signature
from jedi.common import ContextSet, NO_CONTEXTS
class Context(object): class Context(object):
@@ -111,11 +111,11 @@ class AbstractLazyContext(object):
class LazyKnownContext(AbstractLazyContext): class LazyKnownContext(AbstractLazyContext):
"""data is a context.""" """data is a context."""
def infer(self): def infer(self):
return set([self.data]) return ContextSet(self.data)
class LazyKnownContexts(AbstractLazyContext): class LazyKnownContexts(AbstractLazyContext):
"""data is a set of contexts.""" """data is a ContextSet."""
def infer(self): def infer(self):
return self.data return self.data
@@ -125,7 +125,7 @@ class LazyUnknownContext(AbstractLazyContext):
super(LazyUnknownContext, self).__init__(None) super(LazyUnknownContext, self).__init__(None)
def infer(self): def infer(self):
return set() return NO_CONTEXTS
class LazyTreeContext(AbstractLazyContext): class LazyTreeContext(AbstractLazyContext):
@@ -155,7 +155,7 @@ def get_merged_lazy_context(lazy_contexts):
class MergedLazyContexts(AbstractLazyContext): class MergedLazyContexts(AbstractLazyContext):
"""data is a list of lazy contexts.""" """data is a list of lazy contexts."""
def infer(self): def infer(self):
return unite(l.infer() for l in self.data) return ContextSet.from_sets(l.infer() for l in self.data)
class ContextualizedNode(object): class ContextualizedNode(object):

View File

@@ -25,6 +25,7 @@ from jedi.evaluate.utils import unite, indent_block
from jedi.evaluate import context from jedi.evaluate import context
from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate.iterable import SequenceLiteralContext, FakeSequence from jedi.evaluate.iterable import SequenceLiteralContext, FakeSequence
from jedi.common import iterator_to_context_set, ContextSet, NO_CONTEXTS
DOCSTRING_PARAM_PATTERNS = [ DOCSTRING_PARAM_PATTERNS = [
@@ -245,7 +246,7 @@ def infer_param(execution_context, param):
from jedi.evaluate.instance import AnonymousInstanceFunctionExecution from jedi.evaluate.instance import AnonymousInstanceFunctionExecution
def eval_docstring(docstring): def eval_docstring(docstring):
return set( return ContextSet.from_iterable(
p p
for param_str in _search_param_in_docstr(docstring, param.name.value) for param_str in _search_param_in_docstr(docstring, param.name.value)
for p in _evaluate_for_statement_string(module_context, param_str) for p in _evaluate_for_statement_string(module_context, param_str)
@@ -253,7 +254,7 @@ def infer_param(execution_context, param):
module_context = execution_context.get_root_context() module_context = execution_context.get_root_context()
func = param.get_parent_function() func = param.get_parent_function()
if func.type == 'lambdef': if func.type == 'lambdef':
return set() return NO_CONTEXTS
types = eval_docstring(execution_context.py__doc__()) types = eval_docstring(execution_context.py__doc__())
if isinstance(execution_context, AnonymousInstanceFunctionExecution) and \ if isinstance(execution_context, AnonymousInstanceFunctionExecution) and \
@@ -265,6 +266,7 @@ def infer_param(execution_context, param):
@evaluator_method_cache() @evaluator_method_cache()
@iterator_to_context_set
def infer_return_types(function_context): def infer_return_types(function_context):
def search_return_in_docstr(code): def search_return_in_docstr(code):
for p in DOCSTRING_RETURN_PATTERNS: for p in DOCSTRING_RETURN_PATTERNS:
@@ -278,4 +280,3 @@ def infer_return_types(function_context):
for type_str in search_return_in_docstr(function_context.py__doc__()): for type_str in search_return_in_docstr(function_context.py__doc__()):
for type_eval in _evaluate_for_statement_string(function_context.get_root_context(), type_str): for type_eval in _evaluate_for_statement_string(function_context.get_root_context(), type_str):
yield type_eval yield type_eval

View File

@@ -32,6 +32,7 @@ from jedi.evaluate import param
from jedi.evaluate import helpers from jedi.evaluate import helpers
from jedi.evaluate.filters import get_global_filters, TreeNameDefinition from jedi.evaluate.filters import get_global_filters, TreeNameDefinition
from jedi.evaluate.context import ContextualizedName, ContextualizedNode from jedi.evaluate.context import ContextualizedName, ContextualizedNode
from jedi.common import ContextSet
from jedi.parser_utils import is_scope, get_parent_scope from jedi.parser_utils import is_scope, get_parent_scope
@@ -62,7 +63,7 @@ class NameFinder(object):
check = flow_analysis.reachability_check( check = flow_analysis.reachability_check(
self._context, self._context.tree_node, self._name) self._context, self._context.tree_node, self._name)
if check is flow_analysis.UNREACHABLE: if check is flow_analysis.UNREACHABLE:
return set() return ContextSet()
return self._found_predefined_types return self._found_predefined_types
types = self._names_to_types(names, attribute_lookup) types = self._names_to_types(names, attribute_lookup)
@@ -158,22 +159,20 @@ class NameFinder(object):
return inst.execute_function_slots(names, name) return inst.execute_function_slots(names, name)
def _names_to_types(self, names, attribute_lookup): def _names_to_types(self, names, attribute_lookup):
types = set() contexts = ContextSet.from_sets(name.infer() for name in names)
types = unite(name.infer() for name in names) debug.dbg('finder._names_to_types: %s -> %s', names, contexts)
debug.dbg('finder._names_to_types: %s -> %s', names, types)
if not names and isinstance(self._context, AbstractInstanceContext): if not names and isinstance(self._context, AbstractInstanceContext):
# handling __getattr__ / __getattribute__ # handling __getattr__ / __getattribute__
return self._check_getattr(self._context) return self._check_getattr(self._context)
# Add isinstance and other if/assert knowledge. # Add isinstance and other if/assert knowledge.
if not types and isinstance(self._name, tree.Name) and \ if not contexts and isinstance(self._name, tree.Name) and \
not isinstance(self._name_context, AbstractInstanceContext): not isinstance(self._name_context, AbstractInstanceContext):
flow_scope = self._name flow_scope = self._name
base_node = self._name_context.tree_node base_node = self._name_context.tree_node
if base_node.type == 'comp_for': if base_node.type == 'comp_for':
return types return contexts
while True: while True:
flow_scope = get_parent_scope(flow_scope, include_flows=True) flow_scope = get_parent_scope(flow_scope, include_flows=True)
n = _check_flow_information(self._name_context, flow_scope, n = _check_flow_information(self._name_context, flow_scope,
@@ -182,7 +181,7 @@ class NameFinder(object):
return n return n
if flow_scope == base_node: if flow_scope == base_node:
break break
return types return contexts
def _name_to_types(evaluator, context, tree_name): def _name_to_types(evaluator, context, tree_name):
@@ -210,6 +209,7 @@ def _name_to_types(evaluator, context, tree_name):
types = pep0484.find_type_from_comment_hint_with(context, node, tree_name) types = pep0484.find_type_from_comment_hint_with(context, node, tree_name)
if types: if types:
return types return types
if typ in ('for_stmt', 'comp_for'): if typ in ('for_stmt', 'comp_for'):
try: try:
types = context.predefined_names[node][tree_name.value] types = context.predefined_names[node][tree_name.value]
@@ -222,11 +222,8 @@ def _name_to_types(evaluator, context, tree_name):
types = _remove_statements(evaluator, context, node, tree_name) types = _remove_statements(evaluator, context, node, tree_name)
elif typ == 'with_stmt': elif typ == 'with_stmt':
context_managers = context.eval_node(node.get_test_node_from_name(tree_name)) context_managers = context.eval_node(node.get_test_node_from_name(tree_name))
enter_methods = unite( enter_methods = context_managers.py__getattribute__('__enter__')
context_manager.py__getattribute__('__enter__') return enter_methods.execute_evaluated()
for context_manager in context_managers
)
types = unite(method.execute_evaluated() for method in enter_methods)
elif typ in ('import_from', 'import_name'): elif typ in ('import_from', 'import_name'):
types = imports.infer_import(context, tree_name) types = imports.infer_import(context, tree_name)
elif typ in ('funcdef', 'classdef'): elif typ in ('funcdef', 'classdef'):
@@ -262,7 +259,7 @@ def _apply_decorators(evaluator, context, node):
parent_context=context, parent_context=context,
funcdef=node funcdef=node
) )
initial = values = set([decoratee_context]) initial = values = ContextSet(decoratee_context)
for dec in reversed(node.get_decorators()): for dec in reversed(node.get_decorators()):
debug.dbg('decorator: %s %s', dec, values) debug.dbg('decorator: %s %s', dec, values)
dec_values = context.eval_node(dec.children[1]) dec_values = context.eval_node(dec.children[1])
@@ -294,20 +291,12 @@ def _remove_statements(evaluator, context, stmt, name):
Due to lazy evaluation, statements like a = func; b = a; b() have to be Due to lazy evaluation, statements like a = func; b = a; b() have to be
evaluated. evaluated.
""" """
types = set() pep0484_contexts = \
check_instance = None
pep0484types = \
pep0484.find_type_from_comment_hint_assign(context, stmt, name) pep0484.find_type_from_comment_hint_assign(context, stmt, name)
if pep0484types: if pep0484_contexts:
return pep0484types return pep0484_contexts
types |= context.eval_stmt(stmt, seek_name=name)
if check_instance is not None: return context.eval_stmt(stmt, seek_name=name)
# class renames
types = set([er.get_instance_el(evaluator, check_instance, a, True)
if isinstance(a, er.Function) else a for a in types])
return types
def _check_flow_information(context, flow, search_name, pos): def _check_flow_information(context, flow, search_name, pos):
@@ -377,26 +366,26 @@ def _check_isinstance_type(context, element, search_name):
except AssertionError: except AssertionError:
return None return None
result = set() context_set = ContextSet()
for cls_or_tup in lazy_context_cls.infer(): for cls_or_tup in lazy_context_cls.infer():
if isinstance(cls_or_tup, iterable.AbstractSequence) and \ if isinstance(cls_or_tup, iterable.AbstractSequence) and \
cls_or_tup.array_type == 'tuple': cls_or_tup.array_type == 'tuple':
for lazy_context in cls_or_tup.py__iter__(): for lazy_context in cls_or_tup.py__iter__():
for context in lazy_context.infer(): for context in lazy_context.infer():
result |= context.execute_evaluated() context_set |= context.execute_evaluated()
else: else:
result |= cls_or_tup.execute_evaluated() context_set |= cls_or_tup.execute_evaluated()
return result return context_set
def check_tuple_assignments(evaluator, contextualized_name, types): def check_tuple_assignments(evaluator, contextualized_name, context_set):
""" """
Checks if tuples are assigned. Checks if tuples are assigned.
""" """
lazy_context = None lazy_context = None
for index, node in contextualized_name.assignment_indexes(): for index, node in contextualized_name.assignment_indexes():
cn = ContextualizedNode(contextualized_name.context, node) cn = ContextualizedNode(contextualized_name.context, node)
iterated = iterable.py__iter__(evaluator, types, cn) iterated = iterable.py__iter__(evaluator, context_set, cn)
for _ in range(index + 1): for _ in range(index + 1):
try: try:
lazy_context = next(iterated) lazy_context = next(iterated)
@@ -405,6 +394,6 @@ def check_tuple_assignments(evaluator, contextualized_name, types):
# would allow this loop to run for a very long time if the # would allow this loop to run for a very long time if the
# index number is high. Therefore break if the loop is # index number is high. Therefore break if the loop is
# finished. # finished.
return set() return ContextSet()
types = lazy_context.infer() context_set = lazy_context.infer()
return types return context_set

View File

@@ -31,11 +31,12 @@ from jedi.evaluate import compiled
from jedi.evaluate import analysis from jedi.evaluate import analysis
from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate.filters import AbstractNameDefinition from jedi.evaluate.filters import AbstractNameDefinition
from jedi.common import ContextSet, NO_CONTEXTS
# This memoization is needed, because otherwise we will infinitely loop on # This memoization is needed, because otherwise we will infinitely loop on
# certain imports. # certain imports.
@evaluator_method_cache(default=set()) @evaluator_method_cache(default=NO_CONTEXTS)
def infer_import(context, tree_name, is_goto=False): def infer_import(context, tree_name, is_goto=False):
module_context = context.get_root_context() module_context = context.get_root_context()
import_node = search_ancestor(tree_name, 'import_name', 'import_from') import_node = search_ancestor(tree_name, 'import_name', 'import_from')
@@ -63,7 +64,7 @@ def infer_import(context, tree_name, is_goto=False):
# scopes = [NestedImportModule(module, import_node)] # scopes = [NestedImportModule(module, import_node)]
if not types: if not types:
return set() return NO_CONTEXTS
if from_import_name is not None: if from_import_name is not None:
types = unite( types = unite(
@@ -270,7 +271,7 @@ class Importer(object):
def follow(self): def follow(self):
if not self.import_path: if not self.import_path:
return set() return NO_CONTEXTS
return self._do_import(self.import_path, self.sys_path_with_modifications()) return self._do_import(self.import_path, self.sys_path_with_modifications())
def _do_import(self, import_path, sys_path): def _do_import(self, import_path, sys_path):
@@ -296,7 +297,7 @@ class Importer(object):
module_name = '.'.join(import_parts) module_name = '.'.join(import_parts)
try: try:
return set([self._evaluator.modules[module_name]]) return ContextSet(self._evaluator.modules[module_name])
except KeyError: except KeyError:
pass pass
@@ -305,7 +306,7 @@ class Importer(object):
# the module cache. # the module cache.
bases = self._do_import(import_path[:-1], sys_path) bases = self._do_import(import_path[:-1], sys_path)
if not bases: if not bases:
return set() return NO_CONTEXTS
# We can take the first element, because only the os special # We can take the first element, because only the os special
# case yields multiple modules, which is not important for # case yields multiple modules, which is not important for
# further imports. # further imports.
@@ -323,7 +324,7 @@ class Importer(object):
except AttributeError: except AttributeError:
# The module is not a package. # The module is not a package.
_add_error(self.module_context, import_path[-1]) _add_error(self.module_context, import_path[-1])
return set() return NO_CONTEXTS
else: else:
paths = method() paths = method()
debug.dbg('search_module %s in paths %s', module_name, paths) debug.dbg('search_module %s in paths %s', module_name, paths)
@@ -340,7 +341,7 @@ class Importer(object):
module_path = None module_path = None
if module_path is None: if module_path is None:
_add_error(self.module_context, import_path[-1]) _add_error(self.module_context, import_path[-1])
return set() return NO_CONTEXTS
else: else:
parent_module = None parent_module = None
try: try:
@@ -356,7 +357,7 @@ class Importer(object):
except ImportError: except ImportError:
# The module is not a package. # The module is not a package.
_add_error(self.module_context, import_path[-1]) _add_error(self.module_context, import_path[-1])
return set() return NO_CONTEXTS
code = None code = None
if is_pkg: if is_pkg:
@@ -383,10 +384,10 @@ class Importer(object):
if module is None: if module is None:
# The file might raise an ImportError e.g. and therefore not be # The file might raise an ImportError e.g. and therefore not be
# importable. # importable.
return set() return NO_CONTEXTS
self._evaluator.modules[module_name] = module self._evaluator.modules[module_name] = module
return set([module]) return ContextSet(module)
def _generate_name(self, name, in_module=None): def _generate_name(self, name, in_module=None):
# Create a pseudo import to be able to follow them. # Create a pseudo import to be able to follow them.

View File

@@ -11,6 +11,7 @@ from jedi.evaluate.param import AbstractArguments, AnonymousArguments
from jedi.cache import memoize_method from jedi.cache import memoize_method
from jedi.evaluate import representation as er from jedi.evaluate import representation as er
from jedi.evaluate import iterable from jedi.evaluate import iterable
from jedi.common import ContextSet, iterator_to_context_set
from jedi.parser_utils import get_parent_scope from jedi.parser_utils import get_parent_scope
@@ -250,6 +251,7 @@ class CompiledInstanceName(compiled.CompiledName):
super(CompiledInstanceName, self).__init__(evaluator, parent_context, name) super(CompiledInstanceName, self).__init__(evaluator, parent_context, name)
self._instance = instance self._instance = instance
@iterator_to_context_set
def infer(self): def infer(self):
for result_context in super(CompiledInstanceName, self).infer(): for result_context in super(CompiledInstanceName, self).infer():
if isinstance(result_context, er.FunctionContext): if isinstance(result_context, er.FunctionContext):
@@ -311,9 +313,7 @@ class CompiledBoundMethod(compiled.CompiledObject):
class InstanceNameDefinition(filters.TreeNameDefinition): class InstanceNameDefinition(filters.TreeNameDefinition):
def infer(self): def infer(self):
contexts = super(InstanceNameDefinition, self).infer() return super(InstanceNameDefinition, self).infer()
for context in contexts:
yield context
class LazyInstanceName(filters.TreeNameDefinition): class LazyInstanceName(filters.TreeNameDefinition):
@@ -331,6 +331,7 @@ class LazyInstanceName(filters.TreeNameDefinition):
class LazyInstanceClassName(LazyInstanceName): class LazyInstanceClassName(LazyInstanceName):
@iterator_to_context_set
def infer(self): def infer(self):
for result_context in super(LazyInstanceClassName, self).infer(): for result_context in super(LazyInstanceClassName, self).infer():
if isinstance(result_context, er.FunctionContext): if isinstance(result_context, er.FunctionContext):

View File

@@ -35,6 +35,7 @@ from jedi.evaluate import recursion
from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate.filters import DictFilter, AbstractNameDefinition, \ from jedi.evaluate.filters import DictFilter, AbstractNameDefinition, \
ParserTreeFilter ParserTreeFilter
from jedi.common import ContextSet, NO_CONTEXTS
from jedi.parser_utils import get_comp_fors from jedi.parser_utils import get_comp_fors
@@ -87,7 +88,7 @@ class SpecialMethodFilter(DictFilter):
# always only going to be one name. The same is true for the # always only going to be one name. The same is true for the
# inferred values. # inferred values.
builtin_func = next(iter(filter.get(self.string_name)[0].infer())) builtin_func = next(iter(filter.get(self.string_name)[0].infer()))
return set([BuiltinMethod(self.parent_context, self._callable, builtin_func)]) return ContextSet(BuiltinMethod(self.parent_context, self._callable, builtin_func))
def __init__(self, context, dct, builtin_context): def __init__(self, context, dct, builtin_context):
super(SpecialMethodFilter, self).__init__(dct) super(SpecialMethodFilter, self).__init__(dct)
@@ -304,7 +305,7 @@ class ListComprehension(ArrayMixin, Comprehension):
def py__getitem__(self, index): def py__getitem__(self, index):
if isinstance(index, slice): if isinstance(index, slice):
return set([self]) return ContextSet(self)
all_types = list(self.py__iter__()) all_types = list(self.py__iter__())
return all_types[index].infer() return all_types[index].infer()
@@ -339,11 +340,11 @@ class DictComprehension(ArrayMixin, Comprehension):
@register_builtin_method('values') @register_builtin_method('values')
def _imitate_values(self): def _imitate_values(self):
lazy_context = context.LazyKnownContexts(self.dict_values()) lazy_context = context.LazyKnownContexts(self.dict_values())
return set([FakeSequence(self.evaluator, 'list', [lazy_context])]) return ContextSet(FakeSequence(self.evaluator, 'list', [lazy_context]))
@register_builtin_method('items') @register_builtin_method('items')
def _imitate_items(self): def _imitate_items(self):
items = set( items = ContextSet.from_iterable(
FakeSequence( FakeSequence(
self.evaluator, 'tuple' self.evaluator, 'tuple'
(context.LazyKnownContexts(keys), context.LazyKnownContexts(values)) (context.LazyKnownContexts(keys), context.LazyKnownContexts(values))
@@ -385,7 +386,7 @@ class SequenceLiteralContext(ArrayMixin, AbstractSequence):
# Can raise an IndexError # Can raise an IndexError
if isinstance(index, slice): if isinstance(index, slice):
return set([self]) return ContextSet(self)
else: else:
return self._defining_context.eval_node(self._items()[index]) return self._defining_context.eval_node(self._items()[index])
@@ -396,7 +397,7 @@ class SequenceLiteralContext(ArrayMixin, AbstractSequence):
""" """
if self.array_type == 'dict': if self.array_type == 'dict':
# Get keys. # Get keys.
types = set() types = ContextSet()
for k, _ in self._items(): for k, _ in self._items():
types |= self._defining_context.eval_node(k) types |= self._defining_context.eval_node(k)
# We don't know which dict index comes first, therefore always # We don't know which dict index comes first, therefore always
@@ -470,7 +471,7 @@ class DictLiteralContext(SequenceLiteralContext):
@register_builtin_method('values') @register_builtin_method('values')
def _imitate_values(self): def _imitate_values(self):
lazy_context = context.LazyKnownContexts(self.dict_values()) lazy_context = context.LazyKnownContexts(self.dict_values())
return set([FakeSequence(self.evaluator, 'list', [lazy_context])]) return ContextSet(FakeSequence(self.evaluator, 'list', [lazy_context]))
@register_builtin_method('items') @register_builtin_method('items')
def _imitate_items(self): def _imitate_items(self):
@@ -482,7 +483,7 @@ class DictLiteralContext(SequenceLiteralContext):
)) for key_node, value_node in self._items() )) for key_node, value_node in self._items()
] ]
return set([FakeSequence(self.evaluator, 'list', lazy_contexts)]) return ContextSet(FakeSequence(self.evaluator, 'list', lazy_contexts))
class _FakeArray(SequenceLiteralContext): class _FakeArray(SequenceLiteralContext):
@@ -506,7 +507,7 @@ class FakeSequence(_FakeArray):
return self._context_list return self._context_list
def py__getitem__(self, index): def py__getitem__(self, index):
return set(self._lazy_context_list[index].infer()) return self._lazy_context_list[index].infer()
def py__iter__(self): def py__iter__(self):
return self._lazy_context_list return self._lazy_context_list
@@ -641,7 +642,7 @@ def py__iter__types(evaluator, types, contextualized_node=None):
def py__getitem__(evaluator, context, types, trailer): def py__getitem__(evaluator, context, types, trailer):
from jedi.evaluate.representation import ClassContext from jedi.evaluate.representation import ClassContext
from jedi.evaluate.instance import TreeInstance from jedi.evaluate.instance import TreeInstance
result = set() result = ContextSet()
trailer_op, node, trailer_cl = trailer.children trailer_op, node, trailer_cl = trailer.children
assert trailer_op == "[" assert trailer_op == "["
@@ -685,7 +686,7 @@ def py__getitem__(evaluator, context, types, trailer):
try: try:
result |= getitem(index) result |= getitem(index)
except IndexError: except IndexError:
result |= py__iter__types(evaluator, set([typ])) result |= py__iter__types(evaluator, ContextSet(typ))
except KeyError: except KeyError:
# Must be a dict. Lists don't raise KeyErrors. # Must be a dict. Lists don't raise KeyErrors.
result |= typ.dict_values() result |= typ.dict_values()
@@ -696,12 +697,12 @@ def check_array_additions(context, sequence):
""" Just a mapper function for the internal _check_array_additions """ """ Just a mapper function for the internal _check_array_additions """
if sequence.array_type not in ('list', 'set'): if sequence.array_type not in ('list', 'set'):
# TODO also check for dict updates # TODO also check for dict updates
return set() return NO_CONTEXTS
return _check_array_additions(context, sequence) return _check_array_additions(context, sequence)
@evaluator_method_cache(default=set()) @evaluator_method_cache(default=NO_CONTEXTS)
@debug.increase_indent @debug.increase_indent
def _check_array_additions(context, sequence): def _check_array_additions(context, sequence):
""" """
@@ -716,11 +717,11 @@ def _check_array_additions(context, sequence):
module_context = context.get_root_context() module_context = context.get_root_context()
if not settings.dynamic_array_additions or isinstance(module_context, compiled.CompiledObject): if not settings.dynamic_array_additions or isinstance(module_context, compiled.CompiledObject):
debug.dbg('Dynamic array search aborted.', color='MAGENTA') debug.dbg('Dynamic array search aborted.', color='MAGENTA')
return set() return ContextSet()
def find_additions(context, arglist, add_name): def find_additions(context, arglist, add_name):
params = list(param.TreeArguments(context.evaluator, context, arglist).unpack()) params = list(param.TreeArguments(context.evaluator, context, arglist).unpack())
result = set() result = ContextSet()
if add_name in ['insert']: if add_name in ['insert']:
params = params[1:] params = params[1:]
if add_name in ['append', 'add', 'insert']: if add_name in ['append', 'add', 'insert']:
@@ -728,7 +729,9 @@ def _check_array_additions(context, sequence):
result.add(lazy_context) result.add(lazy_context)
elif add_name in ['extend', 'update']: elif add_name in ['extend', 'update']:
for key, lazy_context in params: for key, lazy_context in params:
result |= set(py__iter__(context.evaluator, lazy_context.infer())) result |= ContextSet.from_iterable(
py__iter__(context.evaluator, lazy_context.infer())
)
return result return result
temp_param_add, settings.dynamic_params_for_other_modules = \ temp_param_add, settings.dynamic_params_for_other_modules = \
@@ -737,7 +740,7 @@ def _check_array_additions(context, sequence):
is_list = sequence.name.string_name == 'list' is_list = sequence.name.string_name == 'list'
search_names = (['append', 'extend', 'insert'] if is_list else ['add', 'update']) search_names = (['append', 'extend', 'insert'] if is_list else ['add', 'update'])
added_types = set() added_types = NO_CONTEXTS()
for add_name in search_names: for add_name in search_names:
try: try:
possible_names = module_context.tree_node.get_used_names()[add_name] possible_names = module_context.tree_node.get_used_names()[add_name]
@@ -870,7 +873,7 @@ def create_index_types(evaluator, context, index):
""" """
if index == ':': if index == ':':
# Like array[:] # Like array[:]
return set([Slice(context, None, None, None)]) return ContextSet(Slice(context, None, None, None))
elif index.type == 'subscript' and not index.children[0] == '.': elif index.type == 'subscript' and not index.children[0] == '.':
# subscript basically implies a slice operation, except for Python 2's # subscript basically implies a slice operation, except for Python 2's
@@ -888,7 +891,7 @@ def create_index_types(evaluator, context, index):
result.append(el) result.append(el)
result += [None] * (3 - len(result)) result += [None] * (3 - len(result))
return set([Slice(context, *result)]) return ContextSet(Slice(context, *result))
# No slices # No slices
return context.eval_node(index) return context.eval_node(index)

View File

@@ -237,7 +237,7 @@ class ExecutedParam(object):
pep0484_hints = pep0484.infer_param(self._execution_context, self._param_node) pep0484_hints = pep0484.infer_param(self._execution_context, self._param_node)
doc_params = docstrings.infer_param(self._execution_context, self._param_node) doc_params = docstrings.infer_param(self._execution_context, self._param_node)
if pep0484_hints or doc_params: if pep0484_hints or doc_params:
return list(set(pep0484_hints) | set(doc_params)) return pep0484_hints | doc_params
return self._lazy_context.infer() return self._lazy_context.infer()

View File

@@ -19,7 +19,6 @@ x support for type hint comments for functions, `# type: (int, str) -> int`.
See comment from Guido https://github.com/davidhalter/jedi/issues/662 See comment from Guido https://github.com/davidhalter/jedi/issues/662
""" """
import itertools
import os import os
import re import re
@@ -30,6 +29,7 @@ from jedi.evaluate.utils import unite
from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate import compiled from jedi.evaluate import compiled
from jedi.evaluate.context import LazyTreeContext from jedi.evaluate.context import LazyTreeContext
from jedi.common import NO_CONTEXTS
from jedi import debug from jedi import debug
from jedi import _compatibility from jedi import _compatibility
from jedi import parser_utils from jedi import parser_utils
@@ -42,16 +42,15 @@ def _evaluate_for_annotation(context, annotation, index=None):
and we're interested in that index and we're interested in that index
""" """
if annotation is not None: if annotation is not None:
definitions = context.eval_node( context_set = context.eval_node(_fix_forward_reference(context, annotation))
_fix_forward_reference(context, annotation))
if index is not None: if index is not None:
definitions = list(itertools.chain.from_iterable( context_set = context_set.filter(
definition.py__getitem__(index) for definition in definitions lambda context: context.array_type == 'tuple' \
if definition.array_type == 'tuple' and and len(list(context.py__iter__())) >= index
len(list(definition.py__iter__())) >= index)) ).py__getitem__(index)
return unite(d.execute_evaluated() for d in definitions) return context_set.execute_evaluated()
else: else:
return set() return NO_CONTEXTS
def _fix_forward_reference(context, node): def _fix_forward_reference(context, node):

View File

@@ -7,6 +7,7 @@ from jedi._compatibility import unicode
from jedi import debug from jedi import debug
from jedi.evaluate.compiled import CompiledObject, create, builtin_from_name from jedi.evaluate.compiled import CompiledObject, create, builtin_from_name
from jedi.evaluate import analysis from jedi.evaluate import analysis
from jedi.common import ContextSet
# Maps Python syntax to the operator module. # Maps Python syntax to the operator module.
COMPARISON_OPERATORS = { COMPARISON_OPERATORS = {
@@ -33,7 +34,7 @@ def literals_to_types(evaluator, result):
new_result |= cls.execute_evaluated() new_result |= cls.execute_evaluated()
else: else:
new_result.add(typ) new_result.add(typ)
return new_result return ContextSet.from_set(new_result)
def calculate_children(evaluator, context, children): def calculate_children(evaluator, context, children):
@@ -49,7 +50,7 @@ def calculate_children(evaluator, context, children):
# handle lazy evaluation of and/or here. # handle lazy evaluation of and/or here.
if operator in ('and', 'or'): if operator in ('and', 'or'):
left_bools = set([left.py__bool__() for left in types]) left_bools = ContextSet(left.py__bool__() for left in types)
if left_bools == set([True]): if left_bools == set([True]):
if operator == 'and': if operator == 'and':
types = context.eval_node(right) types = context.eval_node(right)
@@ -65,22 +66,22 @@ def calculate_children(evaluator, context, children):
def calculate(evaluator, context, left_result, operator, right_result): def calculate(evaluator, context, left_result, operator, right_result):
result = set()
if not left_result or not right_result: if not left_result or not right_result:
# illegal slices e.g. cause left/right_result to be None # illegal slices e.g. cause left/right_result to be None
result = (left_result or set()) | (right_result or set()) result = (left_result or set()) | (right_result or set())
result = literals_to_types(evaluator, result) return literals_to_types(evaluator, result)
else: else:
# I don't think there's a reasonable chance that a string # I don't think there's a reasonable chance that a string
# operation is still correct, once we pass something like six # operation is still correct, once we pass something like six
# objects. # objects.
if len(left_result) * len(right_result) > 6: if len(left_result) * len(right_result) > 6:
result = literals_to_types(evaluator, left_result | right_result) return literals_to_types(evaluator, left_result | right_result)
else: else:
for left in left_result: return ContextSet.from_sets(
for right in right_result: _element_calculate(evaluator, context, left, operator, right)
result |= _element_calculate(evaluator, context, left, operator, right) for left in left_result
return result for right in right_result
)
def factor_calculate(evaluator, types, operator): def factor_calculate(evaluator, types, operator):
@@ -131,21 +132,21 @@ def _element_calculate(evaluator, context, left, operator, right):
if operator == '*': if operator == '*':
# for iterables, ignore * operations # for iterables, ignore * operations
if isinstance(left, iterable.AbstractSequence) or is_string(left): if isinstance(left, iterable.AbstractSequence) or is_string(left):
return set([left]) return ContextSet(left)
elif isinstance(right, iterable.AbstractSequence) or is_string(right): elif isinstance(right, iterable.AbstractSequence) or is_string(right):
return set([right]) return ContextSet(right)
elif operator == '+': elif operator == '+':
if l_is_num and r_is_num or is_string(left) and is_string(right): if l_is_num and r_is_num or is_string(left) and is_string(right):
return set([create(evaluator, left.obj + right.obj)]) return ContextSet(create(evaluator, left.obj + right.obj))
elif _is_tuple(left) and _is_tuple(right) or _is_list(left) and _is_list(right): elif _is_tuple(left) and _is_tuple(right) or _is_list(left) and _is_list(right):
return set([iterable.MergedArray(evaluator, (left, right))]) return ContextSet(iterable.MergedArray(evaluator, (left, right)))
elif operator == '-': elif operator == '-':
if l_is_num and r_is_num: if l_is_num and r_is_num:
return set([create(evaluator, left.obj - right.obj)]) return ContextSet(create(evaluator, left.obj - right.obj))
elif operator == '%': elif operator == '%':
# With strings and numbers the left type typically remains. Except for # With strings and numbers the left type typically remains. Except for
# `int() % float()`. # `int() % float()`.
return set([left]) return ContextSet(left)
elif operator in COMPARISON_OPERATORS: elif operator in COMPARISON_OPERATORS:
operation = COMPARISON_OPERATORS[operator] operation = COMPARISON_OPERATORS[operator]
if isinstance(left, CompiledObject) and isinstance(right, CompiledObject): if isinstance(left, CompiledObject) and isinstance(right, CompiledObject):
@@ -157,9 +158,9 @@ def _element_calculate(evaluator, context, left, operator, right):
result = operation(left, right) result = operation(left, right)
except TypeError: except TypeError:
# Could be True or False. # Could be True or False.
return set([create(evaluator, True), create(evaluator, False)]) return ContextSet(create(evaluator, True), create(evaluator, False))
else: else:
return set([create(evaluator, result)]) return ContextSet(create(evaluator, result))
elif operator == 'in': elif operator == 'in':
return set() return set()
@@ -175,4 +176,4 @@ def _element_calculate(evaluator, context, left, operator, right):
analysis.add(context, 'type-error-operation', operator, analysis.add(context, 'type-error-operation', operator,
message % (left, right)) message % (left, right))
return set([left, right]) return ContextSet(left, right)

View File

@@ -64,6 +64,7 @@ from jedi.evaluate.filters import ParserTreeFilter, FunctionExecutionFilter, \
ContextNameMixin ContextNameMixin
from jedi.evaluate import context from jedi.evaluate import context
from jedi.evaluate.context import ContextualizedNode from jedi.evaluate.context import ContextualizedNode
from jedi.common import NO_CONTEXTS, ContextSet, iterator_to_context_set
from jedi import parser_utils from jedi import parser_utils
from jedi.evaluate.parser_cache import get_yield_exprs from jedi.evaluate.parser_cache import get_yield_exprs
@@ -83,6 +84,7 @@ class ClassName(TreeNameDefinition):
super(ClassName, self).__init__(parent_context, tree_name) super(ClassName, self).__init__(parent_context, tree_name)
self._name_context = name_context self._name_context = name_context
@iterator_to_context_set
def infer(self): def infer(self):
# TODO this _name_to_types might get refactored and be a part of the # TODO this _name_to_types might get refactored and be a part of the
# parent class. Once it is, we can probably just overwrite method to # parent class. Once it is, we can probably just overwrite method to
@@ -162,7 +164,7 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def py__call__(self, params): def py__call__(self, params):
from jedi.evaluate.instance import TreeInstance from jedi.evaluate.instance import TreeInstance
return set([TreeInstance(self.evaluator, self.parent_context, self, params)]) return ContextSet(TreeInstance(self.evaluator, self.parent_context, self, params))
def py__class__(self): def py__class__(self):
return compiled.create(self.evaluator, type) return compiled.create(self.evaluator, type)
@@ -227,7 +229,7 @@ class LambdaName(AbstractNameDefinition):
return self._lambda_context.tree_node.start_pos return self._lambda_context.tree_node.start_pos
def infer(self): def infer(self):
return set([self._lambda_context]) return ContextSet(self._lambda_context)
class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)): class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)):
@@ -260,7 +262,7 @@ class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)):
""" """
yield_exprs = get_yield_exprs(self.evaluator, self.tree_node) yield_exprs = get_yield_exprs(self.evaluator, self.tree_node)
if yield_exprs: if yield_exprs:
return set([iterable.Generator(self.evaluator, function_execution)]) return ContextSet(iterable.Generator(self.evaluator, function_execution))
else: else:
return function_execution.get_return_values() return function_execution.get_return_values()
@@ -312,7 +314,7 @@ class FunctionExecutionContext(context.TreeContext):
self.tree_node = function_context.tree_node self.tree_node = function_context.tree_node
self.var_args = var_args self.var_args = var_args
@evaluator_method_cache(default=set()) @evaluator_method_cache(default=NO_CONTEXTS)
@recursion.execution_recursion_decorator() @recursion.execution_recursion_decorator()
def get_return_values(self, check_yields=False): def get_return_values(self, check_yields=False):
funcdef = self.tree_node funcdef = self.tree_node
@@ -320,12 +322,12 @@ class FunctionExecutionContext(context.TreeContext):
return self.evaluator.eval_element(self, funcdef.children[-1]) return self.evaluator.eval_element(self, funcdef.children[-1])
if check_yields: if check_yields:
types = set() context_set = NO_CONTEXTS
returns = get_yield_exprs(self.evaluator, funcdef) returns = get_yield_exprs(self.evaluator, funcdef)
else: else:
returns = funcdef.iter_return_stmts() returns = funcdef.iter_return_stmts()
types = set(docstrings.infer_return_types(self.function_context)) context_set = docstrings.infer_return_types(self.function_context)
types |= set(pep0484.infer_return_types(self.function_context)) context_set |= pep0484.infer_return_types(self.function_context)
for r in returns: for r in returns:
check = flow_analysis.reachability_check(self, funcdef, r) check = flow_analysis.reachability_check(self, funcdef, r)
@@ -333,18 +335,18 @@ class FunctionExecutionContext(context.TreeContext):
debug.dbg('Return unreachable: %s', r) debug.dbg('Return unreachable: %s', r)
else: else:
if check_yields: if check_yields:
types |= set(self._eval_yield(r)) context_set |= ContextSet(self._eval_yield(r))
else: else:
try: try:
children = r.children children = r.children
except AttributeError: except AttributeError:
types.add(compiled.create(self.evaluator, None)) context_set |= ContextSet(compiled.create(self.evaluator, None))
else: else:
types |= self.eval_node(children[1]) context_set |= self.eval_node(children[1])
if check is flow_analysis.REACHABLE: if check is flow_analysis.REACHABLE:
debug.dbg('Return reachable: %s', r) debug.dbg('Return reachable: %s', r)
break break
return types return context_set
def _eval_yield(self, yield_expr): def _eval_yield(self, yield_expr):
if yield_expr.type == 'keyword': if yield_expr.type == 'keyword':
@@ -430,8 +432,10 @@ class ModuleAttributeName(AbstractNameDefinition):
self.string_name = string_name self.string_name = string_name
def infer(self): def infer(self):
return compiled.create(self.parent_context.evaluator, str).execute( return ContextSet(
param.ValuesArguments([]) compiled.create(self.parent_context.evaluator, str).execute(
param.ValuesArguments([])
)
) )
@@ -628,7 +632,7 @@ class ImplicitNSName(AbstractNameDefinition):
self.string_name = string_name self.string_name = string_name
def infer(self): def infer(self):
return [] return NO_CONTEXTS
def get_root_context(self): def get_root_context(self):
return self.implicit_ns_context return self.implicit_ns_context