Support init=False for dataclass_transform

This commit is contained in:
Eric Masseran
2025-03-15 16:00:51 +01:00
parent 77cf382a1b
commit 8912a35502
3 changed files with 95 additions and 37 deletions

View File

@@ -49,7 +49,7 @@ from jedi.inference.arguments import unpack_arglist, ValuesArguments
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
NO_VALUES, ValueWrapper NO_VALUES, ValueWrapper
from jedi.inference.context import ClassContext from jedi.inference.context import ClassContext
from jedi.inference.value.function import FunctionAndClassBase, OverloadedFunctionValue from jedi.inference.value.function import FunctionAndClassBase, FunctionMixin
from jedi.inference.value.decorator import Decoratee from jedi.inference.value.decorator import Decoratee
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
from jedi.plugins import plugin_manager from jedi.plugins import plugin_manager
@@ -433,13 +433,13 @@ class DataclassSignature(AbstractSignature):
return self._param_names return self._param_names
class DataclassDecorator(OverloadedFunctionValue): class DataclassDecorator(ValueWrapper, FunctionMixin):
def __init__(self, function, overloaded_functions, arguments): def __init__(self, function, arguments):
""" """
Args: Args:
arguments: The parameters to the dataclass function decorator. arguments: The parameters to the dataclass function decorator.
""" """
super().__init__(function, overloaded_functions) super().__init__(function)
self.arguments = arguments self.arguments = arguments
@property @property

View File

@@ -615,8 +615,7 @@ def _dataclass(value, arguments, callback):
return ValueSet( return ValueSet(
[ [
DataclassDecorator( DataclassDecorator(
value._wrapped_value, value,
value._overloaded_functions,
arguments=arguments, arguments=arguments,
) )
] ]

View File

@@ -426,16 +426,19 @@ def test_dataclass_signature(
dataclass_transform_cases = [ dataclass_transform_cases = [
# Direct # Attributes on the decorated class and its base classes
['@dataclass_transform\nclass X:', []], # are not considered to be fields.
# With params # 1/ Declare dataclass transformer
['@dataclass_transform(eq=True)\nclass X:', []], # Base Class
['@dataclass_transform\nclass X:', [], False],
# Base Class with params
['@dataclass_transform(eq=True)\nclass X:', [], False],
# Subclass # Subclass
[dedent(''' [dedent('''
class Y(): class Y():
y: int y: int
@dataclass_transform @dataclass_transform
class X(Y):'''), []], class X(Y):'''), [], False],
# Both classes # Both classes
[dedent(''' [dedent('''
@dataclass_transform @dataclass_transform
@@ -443,22 +446,23 @@ dataclass_transform_cases = [
y: int y: int
z = 5 z = 5
@dataclass_transform @dataclass_transform
class X(Y):'''), ['y']], class X(Y):'''), [], False],
# Base class # 2/ Declare dataclass transformed
# Class based
[dedent(''' [dedent('''
@dataclass_transform @dataclass_transform
class Y(): class Y():
y: int y: int
z = 5 z = 5
class X(Y):'''), []], class X(Y):'''), [], True],
# Alternative decorator # Decorator based
[dedent(''' [dedent('''
@dataclass_transform @dataclass_transform
def create_model(cls): def create_model(cls):
return cls return cls
@create_model @create_model
class X:'''), []], class X:'''), [], True],
# Metaclass # Metaclass based
[dedent(''' [dedent('''
@dataclass_transform @dataclass_transform
class ModelMeta(): class ModelMeta():
@@ -467,38 +471,86 @@ dataclass_transform_cases = [
class ModelBase(metaclass=ModelMeta): class ModelBase(metaclass=ModelMeta):
t: int t: int
p = 5 p = 5
class X(ModelBase):'''), []], class X(ModelBase):'''), [], True],
# 3/ Init tweaks
# init=False
[dedent('''
@dataclass_transform(init=False)
class Y():
y: int
z = 5
class X(Y):'''), [], False],
[dedent('''
@dataclass_transform(eq=True, init=False)
class Y():
y: int
z = 5
class X(Y):'''), [], False],
# custom init
[dedent('''
@dataclass_transform()
class Y():
y: int
z = 5
class X(Y):
def __init__(self, toto: str):
pass
'''), ["toto"], False],
] ]
ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"] ids = [
"direct_transformer",
"transformer_with_params",
"subclass_transformer",
"both_transformer",
"base_transformed",
"decorator_transformed",
"metaclass_transformed",
"init_false",
"init_false_multiple",
"custom_init",
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
'start, start_params', dataclass_transform_cases, ids=ids 'start, start_params, include_params', dataclass_transform_cases, ids=ids
) )
def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params): def test_extensions_dataclass_transform_signature(
code = dedent(''' Script, skip_pre_python37, start, start_params, include_params
):
code = dedent(
"""
name: str name: str
foo = 3 foo = 3
price: float price: float
quantity: int = 0.0 quantity: int = 0.0
X(''') X("""
)
code = 'from typing_extensions import dataclass_transform\n' + start + code code = "from typing_extensions import dataclass_transform\n" + start + code
sig, = Script(code).get_signatures() (sig,) = Script(code).get_signatures()
assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] expected_params = (
quantity, = sig.params[-1].infer() [*start_params, "name", "price", "quantity"]
assert quantity.name == 'int' if include_params
price, = sig.params[-2].infer() else [*start_params]
assert price.name == 'float' )
assert [p.name for p in sig.params] == expected_params
if include_params:
quantity, = sig.params[-1].infer()
assert quantity.name == 'int'
price, = sig.params[-2].infer()
assert price.name == 'float'
@pytest.mark.parametrize( @pytest.mark.parametrize(
'start, start_params', dataclass_transform_cases, ids=ids "start, start_params, include_params", dataclass_transform_cases, ids=ids
) )
def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params): def test_dataclass_transform_signature(
Script, skip_pre_python311, start, start_params, include_params
):
code = dedent(''' code = dedent('''
name: str name: str
foo = 3 foo = 3
@@ -510,11 +562,18 @@ def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_
code = 'from typing import dataclass_transform\n' + start + code code = 'from typing import dataclass_transform\n' + start + code
sig, = Script(code).get_signatures() sig, = Script(code).get_signatures()
assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] expected_params = (
quantity, = sig.params[-1].infer() [*start_params, "name", "price", "quantity"]
assert quantity.name == 'int' if include_params
price, = sig.params[-2].infer() else [*start_params]
assert price.name == 'float' )
assert [p.name for p in sig.params] == expected_params
if include_params:
quantity, = sig.params[-1].infer()
assert quantity.name == 'int'
price, = sig.params[-2].infer()
assert price.name == 'float'
@pytest.mark.parametrize( @pytest.mark.parametrize(