diff --git a/test/test_inference/test_signature.py b/test/test_inference/test_signature.py index 4c683f9d..9b9748c5 100644 --- a/test/test_inference/test_signature.py +++ b/test/test_inference/test_signature.py @@ -318,40 +318,101 @@ def test_wraps_signature(Script, code, signature): @pytest.mark.parametrize( - 'start, start_params', [ - ['@dataclass\nclass X:', []], - ['@dataclass(eq=True)\nclass X:', []], - [dedent(''' + "start, start_params, include_params", + [ + ["@dataclass\nclass X:", [], True], + ["@dataclass(eq=True)\nclass X:", [], True], + [ + dedent( + """ class Y(): y: int @dataclass - class X(Y):'''), []], - [dedent(''' + class X(Y):""" + ), + [], + True, + ], + [ + dedent( + """ @dataclass class Y(): y: int z = 5 @dataclass - class X(Y):'''), ['y']], - ] + class X(Y):""" + ), + ["y"], + True, + ], + # init=False + [ + dedent( + """ + @dataclass_transform(init=False) + class Y(): + y: int + z = 5 + class X(Y):""" + ), + [], + False, + ], + # custom init + [ + dedent( + """ + @dataclass_transform() + class Y(): + y: int + z = 5 + class X(Y): + def __init__(self, toto: str): + pass + """ + ), + ["toto"], + False, + ], + ], + ids=[ + "direct_transformed", + "transformed_with_params", + "subclass_transformed", + "both_transformed", + "init_false", + "custom_init", + ], ) -def test_dataclass_signature(Script, skip_pre_python37, start, start_params): - code = dedent(''' +def test_dataclass_signature( + Script, skip_pre_python37, start, start_params, include_params +): + code = dedent( + """ name: str foo = 3 price: float quantity: int = 0.0 - X(''') + X(""" + ) code = 'from dataclasses import dataclass\n' + start + code sig, = Script(code).get_signatures() - assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] - quantity, = sig.params[-1].infer() - assert quantity.name == 'int' - price, = sig.params[-2].infer() - assert price.name == 'float' + expected_params = ( + [*start_params, "name", "price", "quantity"] + if include_params + else [*start_params] + ) + assert [p.name for p in sig.params] == expected_params + + if include_params: + quantity, = sig.params[-1].infer() + assert quantity.name == 'int' + price, = sig.params[-2].infer() + assert price.name == 'float' dataclass_transform_cases = [