Make get_module_names return top-level async functions when all_scopes=False.

This commit is contained in:
forest93
2019-04-11 23:38:55 +08:00
parent c801e24afc
commit 1e12e1e318
2 changed files with 23 additions and 4 deletions

View File

@@ -190,6 +190,9 @@ def get_module_names(module, all_scopes):
# but that would be a big change that could break type inference, whereas for now # but that would be a big change that could break type inference, whereas for now
# this discrepancy looks like only a problem for "get_module_names". # this discrepancy looks like only a problem for "get_module_names".
parent_scope = parent_scope.parent parent_scope = parent_scope.parent
# async functions have an extra wrapper. Strip it.
if parent_scope and parent_scope.type == 'async_stmt':
parent_scope = parent_scope.parent
return parent_scope in (module, None) return parent_scope in (module, None)
names = [n for n in names if is_module_scope_name(n)] names = [n for n in names if is_module_scope_name(n)]

View File

@@ -82,18 +82,34 @@ class TestDefinedNames(TestCase):
def test_class_fields_with_all_scopes_false(self): def test_class_fields_with_all_scopes_false(self):
definitions = self.check_defined_names(""" definitions = self.check_defined_names("""
from module import f from module import f
import asyncio
g = f(f) g = f(f)
class C: class C:
h = g h = g
def __init__(self):
pass
async def __aenter__(self):
pass
def foo(x=a): def foo(x=a):
bar = x bar = x
return bar return bar
""", ['f', 'g', 'C', 'foo'])
C_subdefs = definitions[-2].defined_names() async def async_foo(duration):
foo_subdefs = definitions[-1].defined_names() async def wait():
self.assert_definition_names(C_subdefs, ['h']) await asyncio.sleep(100)
for i in range(duration//100):
await wait()
return duration//100*100
""", ['f', 'asyncio', 'g', 'C', 'foo', 'async_foo'])
C_subdefs = definitions[-3].defined_names()
foo_subdefs = definitions[-2].defined_names()
async_foo_subdefs = definitions[-1].defined_names()
self.assert_definition_names(C_subdefs, ['h', '__init__', '__aenter__'])
self.assert_definition_names(foo_subdefs, ['x', 'bar']) self.assert_definition_names(foo_subdefs, ['x', 'bar'])
self.assert_definition_names(async_foo_subdefs, ['duration', 'wait', 'i'])
def test_follow_imports(environment): def test_follow_imports(environment):