diff --git a/parso/python/tree.py b/parso/python/tree.py index 3865781..7a2ab12 100644 --- a/parso/python/tree.py +++ b/parso/python/tree.py @@ -433,7 +433,10 @@ def _create_params(parent, argslist_list): elif first == '*': return [first] else: # argslist is a `typedargslist` or a `varargslist`. - children = first.children + if first.type == 'tfpdef': + children = [first] + else: + children = first.children new_children = [] start = 0 # Start with offset 1, because the end is higher. @@ -958,8 +961,10 @@ class Param(PythonBaseNode): The default is the test node that appears after the `=`. Is `None` in case no default is present. """ + has_comma = self.children[-1] == ',' try: - return self.children[int(self.children[0] in ('*', '**')) + 2] + if self.children[-2 - int(has_comma)] == '=': + return self.children[-1 - int(has_comma)] except IndexError: return None diff --git a/test/test_parser_tree.py b/test/test_parser_tree.py index 2bf546f..dd8a933 100644 --- a/test/test_parser_tree.py +++ b/test/test_parser_tree.py @@ -69,3 +69,40 @@ def test_end_pos_line(each_version): for i, simple_stmt in enumerate(module.children[:-1]): expr_stmt = simple_stmt.children[0] assert expr_stmt.end_pos == (i + 1, i + 3) + + +def test_default_param(each_version): + func = parse('def x(foo=42): pass', version=each_version).children[0] + param, = func.params + assert param.default.value == '42' + assert param.annotation is None + assert not param.star_count + + +def test_annotation_param(each_py3_version): + func = parse('def x(foo: 3): pass', version=each_py3_version).children[0] + param, = func.params + assert param.default is None + assert param.annotation.value == '3' + assert not param.star_count + + +def test_annotation_params(each_py3_version): + func = parse('def x(foo: 3, bar: 4): pass', version=each_py3_version).children[0] + param1, param2 = func.params + + assert param1.default is None + assert param1.annotation.value == '3' + assert not param1.star_count + + assert param2.default is None + assert param2.annotation.value == '4' + assert not param2.star_count + + +def test_default_and_annotation_param(each_py3_version): + func = parse('def x(foo:3=42): pass', version=each_py3_version).children[0] + param, = func.params + assert param.default.value == '42' + assert param.annotation.value == '3' + assert not param.star_count