Add init cases for dataclass

This commit is contained in:
Eric Masseran
2025-03-15 13:07:35 +01:00
parent efc7248175
commit 68c7bf35ce

View File

@@ -318,40 +318,101 @@ def test_wraps_signature(Script, code, signature):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'start, start_params', [ "start, start_params, include_params",
['@dataclass\nclass X:', []], [
['@dataclass(eq=True)\nclass X:', []], ["@dataclass\nclass X:", [], True],
[dedent(''' ["@dataclass(eq=True)\nclass X:", [], True],
[
dedent(
"""
class Y(): class Y():
y: int y: int
@dataclass @dataclass
class X(Y):'''), []], class X(Y):"""
[dedent(''' ),
[],
True,
],
[
dedent(
"""
@dataclass @dataclass
class Y(): class Y():
y: int y: int
z = 5 z = 5
@dataclass @dataclass
class X(Y):'''), ['y']], class X(Y):"""
] ),
["y"],
True,
],
# init=False
[
dedent(
"""
@dataclass_transform(init=False)
class Y():
y: int
z = 5
class X(Y):"""
),
[],
False,
],
# custom init
[
dedent(
"""
@dataclass_transform()
class Y():
y: int
z = 5
class X(Y):
def __init__(self, toto: str):
pass
"""
),
["toto"],
False,
],
],
ids=[
"direct_transformed",
"transformed_with_params",
"subclass_transformed",
"both_transformed",
"init_false",
"custom_init",
],
) )
def test_dataclass_signature(Script, skip_pre_python37, start, start_params): def test_dataclass_signature(
code = dedent(''' Script, skip_pre_python37, start, start_params, include_params
):
code = dedent(
"""
name: str name: str
foo = 3 foo = 3
price: float price: float
quantity: int = 0.0 quantity: int = 0.0
X(''') X("""
)
code = 'from dataclasses import dataclass\n' + start + code code = 'from dataclasses import dataclass\n' + start + code
sig, = Script(code).get_signatures() sig, = Script(code).get_signatures()
assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] expected_params = (
quantity, = sig.params[-1].infer() [*start_params, "name", "price", "quantity"]
assert quantity.name == 'int' if include_params
price, = sig.params[-2].infer() else [*start_params]
assert price.name == 'float' )
assert [p.name for p in sig.params] == expected_params
if include_params:
quantity, = sig.params[-1].infer()
assert quantity.name == 'int'
price, = sig.params[-2].infer()
assert price.name == 'float'
dataclass_transform_cases = [ dataclass_transform_cases = [