diff --git a/jedi/inference/value/klass.py b/jedi/inference/value/klass.py index ee588391..92d947a6 100644 --- a/jedi/inference/value/klass.py +++ b/jedi/inference/value/klass.py @@ -49,7 +49,7 @@ from jedi.inference.arguments import unpack_arglist, ValuesArguments from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ NO_VALUES, ValueWrapper from jedi.inference.context import ClassContext -from jedi.inference.value.function import FunctionAndClassBase +from jedi.inference.value.function import FunctionAndClassBase, OverloadedFunctionValue from jedi.inference.value.decorator import Decoratee from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.plugins import plugin_manager @@ -433,6 +433,40 @@ class DataclassSignature(AbstractSignature): return self._param_names +class DataclassDecorator(OverloadedFunctionValue): + def __init__(self, function, overloaded_functions, arguments): + """ + Args: + arguments: The parameters to the dataclass function decorator. + """ + super().__init__(function, overloaded_functions) + self.arguments = arguments + + @property + def has_dataclass_init_false(self) -> bool: + """ + Returns: + bool: True if dataclass(init=False) + """ + if not self.arguments.argument_node: + return False + + arg_nodes = ( + self.arguments.argument_node.children + if self.arguments.argument_node.type == "arglist" + else [self.arguments.argument_node] + ) + for arg_node in arg_nodes: + if ( + arg_node.type == "argument" + and arg_node.children[0].value == "init" + and arg_node.children[2].value == "False" + ): + return True + + return False + + class DataclassWrapper(ValueWrapper, ClassMixin): def get_signatures(self): param_names = [] diff --git a/jedi/plugins/stdlib.py b/jedi/plugins/stdlib.py index d85e50db..bcc4fe2b 100644 --- a/jedi/plugins/stdlib.py +++ b/jedi/plugins/stdlib.py @@ -24,7 +24,7 @@ from jedi.inference.value.instance import \ from jedi.inference.base_value import ContextualizedNode, \ NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper from jedi.inference.value import ClassValue, ModuleValue -from jedi.inference.value.klass import DataclassWrapper +from jedi.inference.value.klass import DataclassWrapper, DataclassDecorator from jedi.inference.value.function import FunctionMixin from jedi.inference.value import iterable from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \ @@ -590,11 +590,37 @@ def _random_choice(sequences): def _dataclass(value, arguments, callback): + """ + dataclass decorator can be called 2 times with different arguments. One to + customize it dataclass(eq=True) and another one with the class to + transform. + """ for c in _follow_param(value.inference_state, arguments, 0): if c.is_class(): - return ValueSet([DataclassWrapper(c)]) + # Decorate the class + dataclass_init = ( + # Customized decorator + not value.has_dataclass_init_false + if isinstance(value, DataclassDecorator) + # Bare dataclass decorator + else True + ) + + if dataclass_init: + return ValueSet([DataclassWrapper(c)]) + else: + return ValueSet([c]) else: - return ValueSet([value]) + # Decorator customization + return ValueSet( + [ + DataclassDecorator( + value._wrapped_value, + value._overloaded_functions, + arguments=arguments, + ) + ] + ) return NO_VALUES diff --git a/test/test_inference/test_signature.py b/test/test_inference/test_signature.py index fd65b1b1..48cfafc5 100644 --- a/test/test_inference/test_signature.py +++ b/test/test_inference/test_signature.py @@ -350,11 +350,17 @@ def test_wraps_signature(Script, code, signature): [ dedent( """ - @dataclass_transform(init=False) - class Y(): - y: int - z = 5 - class X(Y):""" + @dataclass(init=False) + class X:""" + ), + [], + False, + ], + [ + dedent( + """ + @dataclass(eq=True, init=False) + class X:""" ), [], False, @@ -363,11 +369,8 @@ def test_wraps_signature(Script, code, signature): [ dedent( """ - @dataclass_transform() - class Y(): - y: int - z = 5 - class X(Y): + @dataclass() + class X: def __init__(self, toto: str): pass """ @@ -382,6 +385,7 @@ def test_wraps_signature(Script, code, signature): "subclass_transformed", "both_transformed", "init_false", + "init_false_multiple", "custom_init", ], ) @@ -418,7 +422,7 @@ def test_dataclass_signature( quantity, = sig.params[-1].infer() assert quantity.name == 'int' price, = sig.params[-2].infer() - assert price.name == 'float' + assert price.name == 'object' dataclass_transform_cases = [