diff --git a/jedi/inference/value/klass.py b/jedi/inference/value/klass.py index 92d947a6..c8c77703 100644 --- a/jedi/inference/value/klass.py +++ b/jedi/inference/value/klass.py @@ -49,7 +49,7 @@ from jedi.inference.arguments import unpack_arglist, ValuesArguments from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ NO_VALUES, ValueWrapper from jedi.inference.context import ClassContext -from jedi.inference.value.function import FunctionAndClassBase, OverloadedFunctionValue +from jedi.inference.value.function import FunctionAndClassBase, FunctionMixin from jedi.inference.value.decorator import Decoratee from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.plugins import plugin_manager @@ -433,13 +433,13 @@ class DataclassSignature(AbstractSignature): return self._param_names -class DataclassDecorator(OverloadedFunctionValue): - def __init__(self, function, overloaded_functions, arguments): +class DataclassDecorator(ValueWrapper, FunctionMixin): + def __init__(self, function, arguments): """ Args: arguments: The parameters to the dataclass function decorator. """ - super().__init__(function, overloaded_functions) + super().__init__(function) self.arguments = arguments @property diff --git a/jedi/plugins/stdlib.py b/jedi/plugins/stdlib.py index bcc4fe2b..0ed2f24c 100644 --- a/jedi/plugins/stdlib.py +++ b/jedi/plugins/stdlib.py @@ -615,8 +615,7 @@ def _dataclass(value, arguments, callback): return ValueSet( [ DataclassDecorator( - value._wrapped_value, - value._overloaded_functions, + value, arguments=arguments, ) ] diff --git a/test/test_inference/test_signature.py b/test/test_inference/test_signature.py index 48cfafc5..cf89709b 100644 --- a/test/test_inference/test_signature.py +++ b/test/test_inference/test_signature.py @@ -426,16 +426,19 @@ def test_dataclass_signature( dataclass_transform_cases = [ - # Direct - ['@dataclass_transform\nclass X:', []], - # With params - ['@dataclass_transform(eq=True)\nclass X:', []], + # Attributes on the decorated class and its base classes + # are not considered to be fields. + # 1/ Declare dataclass transformer + # Base Class + ['@dataclass_transform\nclass X:', [], False], + # Base Class with params + ['@dataclass_transform(eq=True)\nclass X:', [], False], # Subclass [dedent(''' class Y(): y: int @dataclass_transform - class X(Y):'''), []], + class X(Y):'''), [], False], # Both classes [dedent(''' @dataclass_transform @@ -443,22 +446,23 @@ dataclass_transform_cases = [ y: int z = 5 @dataclass_transform - class X(Y):'''), ['y']], - # Base class + class X(Y):'''), [], False], + # 2/ Declare dataclass transformed + # Class based [dedent(''' @dataclass_transform class Y(): y: int z = 5 - class X(Y):'''), []], - # Alternative decorator + class X(Y):'''), [], True], + # Decorator based [dedent(''' @dataclass_transform def create_model(cls): return cls @create_model - class X:'''), []], - # Metaclass + class X:'''), [], True], + # Metaclass based [dedent(''' @dataclass_transform class ModelMeta(): @@ -467,38 +471,86 @@ dataclass_transform_cases = [ class ModelBase(metaclass=ModelMeta): t: int p = 5 - class X(ModelBase):'''), []], + class X(ModelBase):'''), [], True], + # 3/ Init tweaks + # init=False + [dedent(''' + @dataclass_transform(init=False) + class Y(): + y: int + z = 5 + class X(Y):'''), [], False], + [dedent(''' + @dataclass_transform(eq=True, init=False) + class Y(): + y: int + z = 5 + class X(Y):'''), [], False], + # custom init + [dedent(''' + @dataclass_transform() + class Y(): + y: int + z = 5 + class X(Y): + def __init__(self, toto: str): + pass + '''), ["toto"], False], ] -ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"] +ids = [ + "direct_transformer", + "transformer_with_params", + "subclass_transformer", + "both_transformer", + "base_transformed", + "decorator_transformed", + "metaclass_transformed", + "init_false", + "init_false_multiple", + "custom_init", +] @pytest.mark.parametrize( - 'start, start_params', dataclass_transform_cases, ids=ids + 'start, start_params, include_params', dataclass_transform_cases, ids=ids ) -def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params): - code = dedent(''' +def test_extensions_dataclass_transform_signature( + Script, skip_pre_python37, start, start_params, include_params +): + code = dedent( + """ name: str foo = 3 price: float quantity: int = 0.0 - X(''') + X(""" + ) - code = 'from typing_extensions import dataclass_transform\n' + start + code + code = "from typing_extensions import dataclass_transform\n" + start + code - sig, = Script(code).get_signatures() - assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] - quantity, = sig.params[-1].infer() - assert quantity.name == 'int' - price, = sig.params[-2].infer() - assert price.name == 'float' + (sig,) = Script(code).get_signatures() + expected_params = ( + [*start_params, "name", "price", "quantity"] + if include_params + else [*start_params] + ) + assert [p.name for p in sig.params] == expected_params + + if include_params: + quantity, = sig.params[-1].infer() + assert quantity.name == 'int' + price, = sig.params[-2].infer() + assert price.name == 'float' @pytest.mark.parametrize( - 'start, start_params', dataclass_transform_cases, ids=ids + "start, start_params, include_params", dataclass_transform_cases, ids=ids ) -def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params): +def test_dataclass_transform_signature( + Script, skip_pre_python311, start, start_params, include_params +): code = dedent(''' name: str foo = 3 @@ -510,11 +562,18 @@ def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_ code = 'from typing import dataclass_transform\n' + start + code sig, = Script(code).get_signatures() - assert [p.name for p in sig.params] == start_params + ['name', 'price', 'quantity'] - quantity, = sig.params[-1].infer() - assert quantity.name == 'int' - price, = sig.params[-2].infer() - assert price.name == 'float' + expected_params = ( + [*start_params, "name", "price", "quantity"] + if include_params + else [*start_params] + ) + assert [p.name for p in sig.params] == expected_params + + if include_params: + quantity, = sig.params[-1].infer() + assert quantity.name == 'int' + price, = sig.params[-2].infer() + assert price.name == 'float' @pytest.mark.parametrize(