diff --git a/parso/python/normalizer.py b/parso/python/normalizer.py index 72e9cdb..105408c 100644 --- a/parso/python/normalizer.py +++ b/parso/python/normalizer.py @@ -15,10 +15,43 @@ class CompressNormalizer(Normalizer): return leaf.prefix + leaf.value -def is_future_import(from_import): - from_names = from_import.get_from_names() +def _iter_stmts(scope): + """ + Iterates over all statements and splits up simple_stmt. + """ + for child in scope.children: + if child.type == 'simple_stmt': + for child2 in child.children: + if child2.type == 'newline' or child2 == ';': + continue + yield child2 + else: + yield child + + +def _is_future_import(import_from): + if import_from.level != 0: + return False + from_names = import_from.get_from_names() return [n.value for n in from_names] == ['__future__'] + +def _is_future_import_first(import_from): + """ + Checks if the import is the first statement of a file. + """ + found_docstring = False + for stmt in _iter_stmts(import_from.get_root_node()): + if stmt == import_from: + return True + if stmt.type == 'import_from' and _is_future_import(stmt): + continue + if stmt.type == 'string' and not found_docstring: + found_docstring = True + continue + return False + + class Context(object): def __init__(self, node, parent_context=None): self.node = node @@ -92,8 +125,8 @@ class ErrorFinder(Normalizer): yield self._context = context return - elif node.type == 'import_from' and node.level == 0 \ - and is_future_import(node): + elif node.type == 'import_from' \ + and _is_future_import(node) and not _is_future_import_first(node): message = "from __future__ imports must occur at the beginning of the file" self._add_syntax_error(message, node) elif node.type in _STAR_EXPR_PARENTS: diff --git a/test/test_python_errors.py b/test/test_python_errors.py index f596630..9b8adf2 100644 --- a/test/test_python_errors.py +++ b/test/test_python_errors.py @@ -10,8 +10,7 @@ import parso from parso.python.normalizer import ErrorFinderConfig def _get_error_list(code, version=None): - grammar = parso.load_grammar(version=version) - tree = grammar.parse(code) + tree = parso.parse(code, version=version) config = ErrorFinderConfig() return list(tree._get_normalizer_issues(config)) @@ -148,3 +147,26 @@ def test_statically_nested_blocks(): assert get_error(20) assert get_error(20, add_func=True) + + +def test_future_import_first(): + def is_issue(code, *args): + code = code % args + return bool(_get_error_list(code)) + + i1 = 'from __future__ import division' + i2 = 'from __future__ import absolute_import' + assert not is_issue(i1) + assert not is_issue(i1 + ';' + i2) + assert not is_issue(i1 + '\n' + i2) + assert not is_issue('"";' + i1) + assert not is_issue('"";' + i1) + assert not is_issue('""\n' + i1) + assert not is_issue('""\n%s\n%s', i1, i2) + assert not is_issue('""\n%s;%s', i1, i2) + assert not is_issue('"";%s;%s ', i1, i2) + assert not is_issue('"";%s\n%s ', i1, i2) + assert is_issue('1;' + i1) + assert is_issue('1\n' + i1) + assert is_issue('"";1\n' + i1) + assert is_issue('""\n%s\nfrom x import a\n%s', i1, i2)