Fix some default/annotation stuff.

This commit is contained in:
Dave Halter
2017-07-14 02:20:26 +02:00
parent 9d10ceeff1
commit d32a84e181
2 changed files with 44 additions and 2 deletions

View File

@@ -433,7 +433,10 @@ def _create_params(parent, argslist_list):
elif first == '*':
return [first]
else: # argslist is a `typedargslist` or a `varargslist`.
children = first.children
if first.type == 'tfpdef':
children = [first]
else:
children = first.children
new_children = []
start = 0
# Start with offset 1, because the end is higher.
@@ -958,8 +961,10 @@ class Param(PythonBaseNode):
The default is the test node that appears after the `=`. Is `None` in
case no default is present.
"""
has_comma = self.children[-1] == ','
try:
return self.children[int(self.children[0] in ('*', '**')) + 2]
if self.children[-2 - int(has_comma)] == '=':
return self.children[-1 - int(has_comma)]
except IndexError:
return None

View File

@@ -69,3 +69,40 @@ def test_end_pos_line(each_version):
for i, simple_stmt in enumerate(module.children[:-1]):
expr_stmt = simple_stmt.children[0]
assert expr_stmt.end_pos == (i + 1, i + 3)
def test_default_param(each_version):
func = parse('def x(foo=42): pass', version=each_version).children[0]
param, = func.params
assert param.default.value == '42'
assert param.annotation is None
assert not param.star_count
def test_annotation_param(each_py3_version):
func = parse('def x(foo: 3): pass', version=each_py3_version).children[0]
param, = func.params
assert param.default is None
assert param.annotation.value == '3'
assert not param.star_count
def test_annotation_params(each_py3_version):
func = parse('def x(foo: 3, bar: 4): pass', version=each_py3_version).children[0]
param1, param2 = func.params
assert param1.default is None
assert param1.annotation.value == '3'
assert not param1.star_count
assert param2.default is None
assert param2.annotation.value == '4'
assert not param2.star_count
def test_default_and_annotation_param(each_py3_version):
func = parse('def x(foo:3=42): pass', version=each_py3_version).children[0]
param, = func.params
assert param.default.value == '42'
assert param.annotation.value == '3'
assert not param.star_count