Fix some nonlocal/global issues.

This commit is contained in:
Dave Halter
2017-07-30 01:02:36 +02:00
parent 4ff25e8c92
commit b7726d05cf
2 changed files with 73 additions and 5 deletions

View File

@@ -93,11 +93,14 @@ def _is_future_import_first(import_from):
class Context(object):
def __init__(self, node, parent_context=None):
def __init__(self, node, add_syntax_error, parent_context=None):
self.node = node
self.blocks = []
self.parent_context = parent_context
self.used_names = dict()
self._used_name_dict = {}
self._global_names = []
self._nonlocal_names = []
self._add_syntax_error = add_syntax_error
def is_async_funcdef(self):
# Stupidly enough async funcdefs can have two different forms,
@@ -105,10 +108,44 @@ class Context(object):
return self.is_function() \
and self.node.parent.type in ('async_funcdef', 'async_stmt')
def is_function(self):
return self.node.type == 'funcdef'
def add_name(self, name):
parent_type = name.parent.type
if parent_type == 'trailer':
# We are only interested in first level names.
return
if parent_type == 'global_stmt':
self._global_names.append(name)
elif parent_type == 'nonlocal_stmt':
self._nonlocal_names.append(name)
else:
self._used_name_dict.setdefault(name.value, []).append(name)
def finalize(self):
self._analyze_names(self._global_names, 'global')
self._analyze_names(self._nonlocal_names, 'nonlocal')
def _analyze_names(self, globals_or_nonlocals, type_):
for base_name in globals_or_nonlocals:
search = base_name.value
# Somehow Python does it the reversed way.
for name in reversed(self._used_name_dict.get(search, [])):
print(name, name.parent)
if name.start_pos > base_name.start_pos:
# All following names don't have to be checked.
break
if name.is_definition():
message = "name '%s' is assigned to before %s declaration"
else:
message = "name '%s' is used prior to %s declaration"
self._add_syntax_error(message % (name.value, type_), base_name)
# Only add an error for the first occurence.
break
@contextmanager
def add_block(self, node):
self.blocks.append(node)
@@ -117,7 +154,9 @@ class Context(object):
@contextmanager
def add_context(self, node):
yield Context(node, parent_context=self)
new_context = Context(node, self._add_syntax_error, parent_context=self)
yield new_context
new_context.finalize()
class ErrorFinder(Normalizer):
@@ -135,7 +174,7 @@ class ErrorFinder(Normalizer):
parent_scope = node
else:
parent_scope = search_ancestor(node, allowed)
self._context = Context(parent_scope)
self._context = Context(parent_scope, self._add_syntax_error)
@contextmanager
def visit_node(self, node):
@@ -392,6 +431,8 @@ class ErrorFinder(Normalizer):
if leaf.value == '__debug__' and leaf.is_definition():
message = 'assignment to keyword'
self._add_syntax_error(message, leaf)
self._context.add_name(leaf)
elif leaf.type == 'string':
if 'b' in leaf.string_prefix.lower() \
and any(c for c in leaf.value if ord(c) > 127):

View File

@@ -150,6 +150,32 @@ def test_indentation_errors(code, positions):
'def f(x, x): pass',
'def x(): from math import *',
'nonlocal a',
dedent('''
def glob():
x = 3
x.z
global x'''),
dedent('''
def glob():
x = 3
global x'''),
dedent('''
def glob():
x
global x'''),
dedent('''
def glob():
x = 3
x.z
nonlocal x'''),
dedent('''
def glob():
x = 3
nonlocal x'''),
dedent('''
def glob():
x
nonlocal x'''),
# IndentationError
' foo',
@@ -195,6 +221,7 @@ def test_python_exception_matches(code):
('{**{} for a in [1]}', '3.5'),
('"s" b""', '3.5'),
('b"ä"', '3.5'),
#('(%s *d) = x' % ('a,' * 256), '3.5')
]
)
def test_python_exception_matches_version(code, version):