From 647073b1b928691aba5497b8673e06adc905cbce Mon Sep 17 00:00:00 2001 From: Alisdair Robertson Date: Sat, 28 Oct 2017 22:35:49 +1100 Subject: [PATCH] Iter raise statements in a Function (#16) * Add Function.iter_raise_stmts method and tests * Add Alisdair Robertson to AUTHORS.txt * Cleanup Function.iter_raise_stmts and test Decided not to try and exclude exceptions that would be caught by a try-catch --- AUTHORS.txt | 1 + parso/python/tree.py | 15 +++++++++++++++ test/test_parser_tree.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/AUTHORS.txt b/AUTHORS.txt index c26821a..e16c04c 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -5,6 +5,7 @@ David Halter (@davidhalter) Code Contributors ================= +Alisdair Robertson (@robodair) Code Contributors (to Jedi and therefore possibly to this library) diff --git a/parso/python/tree.py b/parso/python/tree.py index 54d1aa9..a38b70d 100644 --- a/parso/python/tree.py +++ b/parso/python/tree.py @@ -592,6 +592,21 @@ class Function(ClassOrFunc): return scan(self.children) + def iter_raise_stmts(self): + """ + Returns a generator of `raise_stmt`. Includes raise statements inside try-except blocks + """ + def scan(children): + for element in children: + if element.type == 'raise_stmt' \ + or element.type == 'keyword' and element.value == 'raise': + yield element + if element.type in _RETURN_STMT_CONTAINERS: + for e in scan(element.children): + yield e + + return scan(self.children) + def is_generator(self): """ :return bool: Checks if a function is a generator or not. diff --git a/test/test_parser_tree.py b/test/test_parser_tree.py index a39ea74..69cfe76 100644 --- a/test/test_parser_tree.py +++ b/test/test_parser_tree.py @@ -125,6 +125,10 @@ def get_return_stmts(code): return list(parse(code).children[0].iter_return_stmts()) +def get_raise_stmts(code, child): + return list(parse(code).children[child].iter_raise_stmts()) + + def test_yields(each_version): y, = get_yield_exprs('def x(): yield', each_version) assert y.value == 'yield' @@ -149,3 +153,30 @@ def test_returns(): r, = get_return_stmts('def x(): return 1') assert r.type == 'return_stmt' + + +def test_raises(): + code = """ +def single_function(): + raise Exception +def top_function(): + def inner_function(): + raise NotImplementedError() + inner_function() + raise Exception +def top_function_three(): + try: + raise NotImplementedError() + except NotImplementedError: + pass + raise Exception + """ + + r = get_raise_stmts(code, 0) # Lists in a simple Function + assert len(list(r)) == 1 + + r = get_raise_stmts(code, 1) # Doesn't Exceptions list in closures + assert len(list(r)) == 1 + + r = get_raise_stmts(code, 2) # Lists inside try-catch + assert len(list(r)) == 2