diff --git a/jedi/evaluate/compiled/access.py b/jedi/evaluate/compiled/access.py index 691471fe..171873ac 100644 --- a/jedi/evaluate/compiled/access.py +++ b/jedi/evaluate/compiled/access.py @@ -455,8 +455,10 @@ if py_version >= 35: CoroutineType = type(_coroutine) _coroutine.close() # Prevent ResourceWarning """), 'blub', 'exec')) + _coroutine_wrapper = _coroutine.__await__() else: _coroutine = None + _coroutine_wrapper = None if py_version >= 36: exec(compile(dedent(""" @@ -475,6 +477,7 @@ class _SPECIAL_OBJECTS(object): GENERATOR_OBJECT = _a_generator(1.0) BUILTINS = builtins COROUTINE = _coroutine + COROUTINE_WRAPPER = _coroutine_wrapper ASYNC_GENERATOR = _async_generator diff --git a/jedi/evaluate/context/asynchronous.py b/jedi/evaluate/context/asynchronous.py index 8b82dc35..4a3ec8ef 100644 --- a/jedi/evaluate/context/asynchronous.py +++ b/jedi/evaluate/context/asynchronous.py @@ -1,17 +1,18 @@ from jedi.evaluate.filters import publish_method, BuiltinOverwrite +from jedi.evaluate.base_context import ContextSet class AsyncBase(BuiltinOverwrite): def __init__(self, evaluator, func_execution_context): super(AsyncBase, self).__init__(evaluator) - self._func_execution_context = func_execution_context + self.func_execution_context = func_execution_context @property def name(self): - return self.get_builtin_object().py__name__() + return self.get_object().name def __repr__(self): - return "<%s of %s>" % (type(self).__name__, self._func_execution_context) + return "<%s of %s>" % (type(self).__name__, self.func_execution_context) class Coroutine(AsyncBase): @@ -19,7 +20,14 @@ class Coroutine(AsyncBase): @publish_method('__await__') def _await(self): - return self._func_execution_context.get_return_values() + return ContextSet(CoroutineWrapper(self.evaluator, self.func_execution_context)) + + +class CoroutineWrapper(AsyncBase): + special_object_identifier = u'COROUTINE_WRAPPER' + + def py__stop_iteration_returns(self): + return self.func_execution_context.get_return_values() class AsyncGenerator(AsyncBase): diff --git a/jedi/evaluate/syntax_tree.py b/jedi/evaluate/syntax_tree.py index 2ea81409..ce5c329b 100644 --- a/jedi/evaluate/syntax_tree.py +++ b/jedi/evaluate/syntax_tree.py @@ -50,6 +50,18 @@ def _limit_context_infers(func): return wrapper +def _py__stop_iteration_returns(generators): + results = ContextSet() + for generator in generators: + try: + method = generator.py__stop_iteration_returns + except AttributeError: + debug.warning('%s is not actually a generator', generator) + else: + results |= method() + return results + + @debug.increase_indent @_limit_context_infers def eval_node(context, element): @@ -94,7 +106,8 @@ def eval_node(context, element): await_context_set = context_set.py__getattribute__(u"__await__") if not await_context_set: debug.warning('Tried to run py__await__ on context %s', context) - return await_context_set.execute_evaluated() + context_set = ContextSet() + return _py__stop_iteration_returns(await_context_set.execute_evaluated()) return context_set elif typ in ('testlist_star_expr', 'testlist',): # The implicit tuple in statements. @@ -129,15 +142,7 @@ def eval_node(context, element): # Implies that it's a yield from. element = element.children[1].children[1] generators = context.eval_node(element) - results = ContextSet() - for generator in generators: - try: - method = generator.py__stop_iteration_returns - except AttributeError: - debug.warning('%s is not actually a generator', generator) - else: - results |= method() - return results + return _py__stop_iteration_returns(generators) # Generator.send() is not implemented. return NO_CONTEXTS @@ -326,7 +331,7 @@ def eval_or_test(context, or_test): # Otherwise continue, because of uncertainty. else: types = _eval_comparison(context.evaluator, context, types, operator, - context.eval_node(right)) + context.eval_node(right)) debug.dbg('eval_or_test types %s', types) return types diff --git a/test/completion/async_.py b/test/completion/async_.py index cdee73c1..a5484173 100644 --- a/test/completion/async_.py +++ b/test/completion/async_.py @@ -24,8 +24,8 @@ async def y(): argh = await x() #? int() argh - #? int() - x().__await__() + #? ['__next__'] + x().__await__().__next return 2 async def x2(): @@ -51,8 +51,7 @@ class Awaitable: async def awaitable_test(): foo = await Awaitable() - # TODO doesn't work yet. - ##? int() + #? str() foo # python >= 3.6