diff --git a/jedi/evaluate/base_context.py b/jedi/evaluate/base_context.py index 693a99aa..c53e6d35 100644 --- a/jedi/evaluate/base_context.py +++ b/jedi/evaluate/base_context.py @@ -63,10 +63,13 @@ class Context(BaseContext): arguments = ValuesArguments([ContextSet(value) for value in value_list]) return self.execute(arguments) - def iterate(self, contextualized_node=None): + def iterate(self, contextualized_node=None, is_async=False): debug.dbg('iterate') try: - iter_method = self.py__iter__ + if is_async: + iter_method = self.py__aiter__ + else: + iter_method = self.py__iter__ except AttributeError: if contextualized_node is not None: from jedi.evaluate import analysis @@ -241,9 +244,9 @@ class ContextSet(BaseContextSet): def py__class__(self): return ContextSet.from_iterable(c.py__class__() for c in self._set) - def iterate(self, contextualized_node=None): + def iterate(self, contextualized_node=None, is_async=False): from jedi.evaluate.lazy_context import get_merged_lazy_context - type_iters = [c.iterate(contextualized_node) for c in self._set] + type_iters = [c.iterate(contextualized_node, is_async=is_async) for c in self._set] for lazy_contexts in zip_longest(*type_iters): yield get_merged_lazy_context( [l for l in lazy_contexts if l is not None] diff --git a/jedi/evaluate/context/function.py b/jedi/evaluate/context/function.py index 0dba9c91..b9c481e3 100644 --- a/jedi/evaluate/context/function.py +++ b/jedi/evaluate/context/function.py @@ -63,11 +63,19 @@ class FunctionContext(use_metaclass(CachedMetaClass, TreeContext)): """ Created to be used by inheritance. """ - yield_exprs = get_yield_exprs(self.evaluator, self.tree_node) - if yield_exprs: - return ContextSet(iterable.Generator(self.evaluator, function_execution)) + is_coroutine = self.tree_node.parent.type == 'async_stmt' + is_generator = bool(get_yield_exprs(self.evaluator, self.tree_node)) + + if is_coroutine: + if is_generator: + return ContextSet(iterable.AsyncGenerator(self.evaluator, function_execution)) + else: + return ContextSet(iterable.Coroutine(self.evaluator, function_execution)) else: - return function_execution.get_return_values() + if is_generator: + return ContextSet(iterable.Generator(self.evaluator, function_execution)) + else: + return function_execution.get_return_values() def get_function_execution(self, arguments=None): if arguments is None: @@ -169,7 +177,8 @@ class FunctionExecutionContext(TreeContext): yield LazyTreeContext(self, node) @recursion.execution_recursion_decorator(default=iter([])) - def get_yield_values(self): + def get_yield_values(self, is_async=False): + # TODO: if is_async, wrap yield statements in Awaitable/async_generator_asend for_parents = [(y, tree.search_ancestor(y, 'for_stmt', 'funcdef', 'while_stmt', 'if_stmt')) for y in get_yield_exprs(self.evaluator, self.tree_node)] diff --git a/jedi/evaluate/context/iterable.py b/jedi/evaluate/context/iterable.py index d0f468e4..a755d89c 100644 --- a/jedi/evaluate/context/iterable.py +++ b/jedi/evaluate/context/iterable.py @@ -37,6 +37,19 @@ from jedi.evaluate.base_context import ContextSet, NO_CONTEXTS, Context, \ TreeContext, ContextualizedNode from jedi.parser_utils import get_comp_fors +try: + from types import CoroutineType +except ImportError: + HAS_COROUTINE = False +else: + HAS_COROUTINE = True + +try: + from types import AsyncGeneratorType +except ImportError: + HAS_ASYNC_GENERATOR = False +else: + HAS_ASYNC_GENERATOR = True class AbstractIterable(Context): builtin_methods = {} @@ -53,6 +66,83 @@ class AbstractIterable(Context): return compiled.CompiledContextName(self, self.array_type) +@has_builtin_methods +class CoroutineMixin(object): + array_type = None + + def get_filters(self, search_global, until_position=None, origin_scope=None): + gen_obj = compiled.create(self.evaluator, CoroutineType) + yield SpecialMethodFilter(self, self.builtin_methods, gen_obj) + for filter in gen_obj.get_filters(search_global): + yield filter + + def py__bool__(self): + return True + + def py__class__(self): + gen_obj = compiled.create(self.evaluator, CoroutineType) + return gen_obj.py__class__() + + @property + def name(self): + return compiled.CompiledContextName(self, 'coroutine') + + +class Coroutine(CoroutineMixin, Context): + def __init__(self, evaluator, func_execution_context): + if not HAS_COROUTINE: + raise ImportError("Need python3.5 to support coroutines.") + super(Coroutine, self).__init__(evaluator, parent_context=evaluator.BUILTINS) + self._func_execution_context = func_execution_context + + def execute_await(self): + return self._func_execution_context.get_return_values() + + def __repr__(self): + return "<%s of %s>" % (type(self).__name__, self._func_execution_context) + + +@has_builtin_methods +class AsyncGeneratorMixin(object): + array_type = None + + @register_builtin_method('__anext__') + def py__anext__(self): + return ContextSet.from_sets(lazy_context.infer() for lazy_context in self.py__aiter__()) + + def get_filters(self, search_global, until_position=None, origin_scope=None): + gen_obj = compiled.create(self.evaluator, AsyncGeneratorType) + yield SpecialMethodFilter(self, self.builtin_methods, gen_obj) + for filter in gen_obj.get_filters(search_global): + yield filter + + def py__bool__(self): + return True + + def py__class__(self): + gen_obj = compiled.create(self.evaluator, AsyncGeneratorType) + return gen_obj.py__class__() + + @property + def name(self): + return compiled.CompiledContextName(self, 'asyncgenerator') + + +class AsyncGenerator(AsyncGeneratorMixin, Context): + """Handling of `yield` functions.""" + def __init__(self, evaluator, func_execution_context): + if not HAS_ASYNC_GENERATOR: + raise ImportError("Need python3.6 to support async generators.") + super(AsyncGenerator, self).__init__(evaluator, parent_context=evaluator.BUILTINS) + self._func_execution_context = func_execution_context + + def py__aiter__(self): + return self._func_execution_context.get_yield_values(is_async=True) + + def __repr__(self): + return "<%s of %s>" % (type(self).__name__, self._func_execution_context) + + @has_builtin_methods class GeneratorMixin(object): array_type = None @@ -126,17 +216,18 @@ class Comprehension(AbstractIterable): cls = ListComprehension return cls(evaluator, context, atom) - def __init__(self, evaluator, defining_context, atom): + def __init__(self, evaluator, defining_context, atom, is_async=False): super(Comprehension, self).__init__(evaluator) self._defining_context = defining_context self._atom = atom def _get_comprehension(self): + "return 'a for a in b'" # The atom contains a testlist_comp return self._atom.children[1] def _get_comp_for(self): - # The atom contains a testlist_comp + "return CompFor('for a in b')" return self._get_comprehension().children[1] def _eval_node(self, index=0): @@ -154,13 +245,17 @@ class Comprehension(AbstractIterable): def _nested(self, comp_fors, parent_context=None): comp_for = comp_fors[0] - input_node = comp_for.children[3] + + is_async = 'async' == comp_for.children[comp_for.children.index('for') - 1] + + input_node = comp_for.children[comp_for.children.index('in') + 1] parent_context = parent_context or self._defining_context input_types = parent_context.eval_node(input_node) + # TODO: simulate await if self.is_async cn = ContextualizedNode(parent_context, input_node) - iterated = input_types.iterate(cn) - exprlist = comp_for.children[1] + iterated = input_types.iterate(cn, is_async=is_async) + exprlist = comp_for.children[comp_for.children.index('for') + 1] for i, lazy_context in enumerate(iterated): types = lazy_context.infer() dct = unpack_tuple_to_dict(parent_context, types, exprlist) @@ -649,7 +744,7 @@ class _ArrayInstance(object): for addition in additions: yield addition - def iterate(self, contextualized_node=None): + def iterate(self, contextualized_node=None, is_async=False): return self.py__iter__() diff --git a/jedi/evaluate/syntax_tree.py b/jedi/evaluate/syntax_tree.py index 1d847a49..f41f69bb 100644 --- a/jedi/evaluate/syntax_tree.py +++ b/jedi/evaluate/syntax_tree.py @@ -68,22 +68,38 @@ def eval_node(context, element): return eval_expr_stmt(context, element) elif typ in ('power', 'atom_expr'): first_child = element.children[0] - if not (first_child.type == 'keyword' and first_child.value == 'await'): - context_set = eval_atom(context, first_child) - for trailer in element.children[1:]: - if trailer == '**': # has a power operation. - right = evaluator.eval_element(context, element.children[2]) - context_set = _eval_comparison( - evaluator, - context, - context_set, - trailer, - right - ) - break - context_set = eval_trailer(context, context_set, trailer) - return context_set - return NO_CONTEXTS + children = element.children[1:] + had_await = False + if first_child.type == 'keyword' and first_child.value == 'await': + had_await = True + first_child = children.pop(0) + + context_set = eval_atom(context, first_child) + for trailer in children: + if trailer == '**': # has a power operation. + right = evaluator.eval_element(context, children[1]) + context_set = _eval_comparison( + evaluator, + context, + context_set, + trailer, + right + ) + break + context_set = eval_trailer(context, context_set, trailer) + + if had_await: + await_context_set = ContextSet() + for context in context_set: + try: + func = context.execute_await + except AttributeError: + debug.warning('Tried to run execute_await on context %s', context) + pass + else: + await_context_set |= func() + return await_context_set + return context_set elif typ in ('testlist_star_expr', 'testlist',): # The implicit tuple in statements. return ContextSet(iterable.SequenceLiteralContext(evaluator, context, element)) diff --git a/test/completion/async_.py b/test/completion/async_.py index b2202137..797c0dd0 100644 --- a/test/completion/async_.py +++ b/test/completion/async_.py @@ -6,17 +6,39 @@ raise errors or return extremely strange results. """ async def x(): + return 1 + +#? [] +x.cr_awai + +#? ['cr_await'] +x().cr_awai + +a = await x() +#? int() +a + +async def y(): argh = await x() - #? + #? int() argh return 2 -#? int() -x() +async def asgen(): + yield 1 + await asyncio.sleep(0) + yield 2 -a = await x() -#? -a +async def wrapper(): + #? int() + [x async for x in asgen()][0] + + async for y in asgen(): + # TODO: make this an int() + y + +#? ['__anext__'] +asgen().__ane async def x2():