Support init=False for dataclass_transform

This commit is contained in:
Eric Masseran
2025-03-15 16:00:51 +01:00
parent 77cf382a1b
commit 8912a35502
3 changed files with 95 additions and 37 deletions

View File

@@ -49,7 +49,7 @@ from jedi.inference.arguments import unpack_arglist, ValuesArguments
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
NO_VALUES, ValueWrapper
from jedi.inference.context import ClassContext
from jedi.inference.value.function import FunctionAndClassBase, OverloadedFunctionValue
from jedi.inference.value.function import FunctionAndClassBase, FunctionMixin
from jedi.inference.value.decorator import Decoratee
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
from jedi.plugins import plugin_manager
@@ -433,13 +433,13 @@ class DataclassSignature(AbstractSignature):
return self._param_names
class DataclassDecorator(OverloadedFunctionValue):
def __init__(self, function, overloaded_functions, arguments):
class DataclassDecorator(ValueWrapper, FunctionMixin):
def __init__(self, function, arguments):
"""
Args:
arguments: The parameters to the dataclass function decorator.
"""
super().__init__(function, overloaded_functions)
super().__init__(function)
self.arguments = arguments
@property

View File

@@ -615,8 +615,7 @@ def _dataclass(value, arguments, callback):
return ValueSet(
[
DataclassDecorator(
value._wrapped_value,
value._overloaded_functions,
value,
arguments=arguments,
)
]

View File

@@ -426,16 +426,19 @@ def test_dataclass_signature(
dataclass_transform_cases = [
# Direct
['@dataclass_transform\nclass X:', []],
# With params
['@dataclass_transform(eq=True)\nclass X:', []],
# Attributes on the decorated class and its base classes
# are not considered to be fields.
# 1/ Declare dataclass transformer
# Base Class
['@dataclass_transform\nclass X:', [], False],
# Base Class with params
['@dataclass_transform(eq=True)\nclass X:', [], False],
# Subclass
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
class X(Y):'''), [], False],
# Both classes
[dedent('''
@dataclass_transform
@@ -443,22 +446,23 @@ dataclass_transform_cases = [
y: int
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
# Base class
class X(Y):'''), [], False],
# 2/ Declare dataclass transformed
# Class based
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
class X(Y):'''), []],
# Alternative decorator
class X(Y):'''), [], True],
# Decorator based
[dedent('''
@dataclass_transform
def create_model(cls):
return cls
@create_model
class X:'''), []],
# Metaclass
class X:'''), [], True],
# Metaclass based
[dedent('''
@dataclass_transform
class ModelMeta():
@@ -467,28 +471,74 @@ dataclass_transform_cases = [
class ModelBase(metaclass=ModelMeta):
t: int
p = 5
class X(ModelBase):'''), []],
class X(ModelBase):'''), [], True],
# 3/ Init tweaks
# init=False
[dedent('''
@dataclass_transform(init=False)
class Y():
y: int
z = 5
class X(Y):'''), [], False],
[dedent('''
@dataclass_transform(eq=True, init=False)
class Y():
y: int
z = 5
class X(Y):'''), [], False],
# custom init
[dedent('''
@dataclass_transform()
class Y():
y: int
z = 5
class X(Y):
def __init__(self, toto: str):
pass
'''), ["toto"], False],
]
ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"]
ids = [
"direct_transformer",
"transformer_with_params",
"subclass_transformer",
"both_transformer",
"base_transformed",
"decorator_transformed",
"metaclass_transformed",
"init_false",
"init_false_multiple",
"custom_init",
]
@pytest.mark.parametrize(
'start, start_params', dataclass_transform_cases, ids=ids
'start, start_params, include_params', dataclass_transform_cases, ids=ids
)
def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params):
code = dedent('''
def test_extensions_dataclass_transform_signature(
Script, skip_pre_python37, start, start_params, include_params
):
code = dedent(
"""
name: str
foo = 3
price: float
quantity: int = 0.0
X(''')
X("""
)
code = 'from typing_extensions import dataclass_transform\n' + start + code
code = "from typing_extensions import dataclass_transform\n" + start + code
sig, = Script(code).get_signatures()
assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity']
(sig,) = Script(code).get_signatures()
expected_params = (
[*start_params, "name", "price", "quantity"]
if include_params
else [*start_params]
)
assert [p.name for p in sig.params] == expected_params
if include_params:
quantity, = sig.params[-1].infer()
assert quantity.name == 'int'
price, = sig.params[-2].infer()
@@ -496,9 +546,11 @@ def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, sta
@pytest.mark.parametrize(
'start, start_params', dataclass_transform_cases, ids=ids
"start, start_params, include_params", dataclass_transform_cases, ids=ids
)
def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params):
def test_dataclass_transform_signature(
Script, skip_pre_python311, start, start_params, include_params
):
code = dedent('''
name: str
foo = 3
@@ -510,7 +562,14 @@ def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_
code = 'from typing import dataclass_transform\n' + start + code
sig, = Script(code).get_signatures()
assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity']
expected_params = (
[*start_params, "name", "price", "quantity"]
if include_params
else [*start_params]
)
assert [p.name for p in sig.params] == expected_params
if include_params:
quantity, = sig.params[-1].infer()
assert quantity.name == 'int'
price, = sig.params[-2].infer()