Support init=False for dataclass

This commit is contained in:
Eric Masseran
2025-03-15 15:53:51 +01:00
parent 70efe2134c
commit 77cf382a1b
3 changed files with 79 additions and 15 deletions

View File

@@ -49,7 +49,7 @@ from jedi.inference.arguments import unpack_arglist, ValuesArguments
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
NO_VALUES, ValueWrapper NO_VALUES, ValueWrapper
from jedi.inference.context import ClassContext from jedi.inference.context import ClassContext
from jedi.inference.value.function import FunctionAndClassBase from jedi.inference.value.function import FunctionAndClassBase, OverloadedFunctionValue
from jedi.inference.value.decorator import Decoratee from jedi.inference.value.decorator import Decoratee
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
from jedi.plugins import plugin_manager from jedi.plugins import plugin_manager
@@ -433,6 +433,40 @@ class DataclassSignature(AbstractSignature):
return self._param_names return self._param_names
class DataclassDecorator(OverloadedFunctionValue):
def __init__(self, function, overloaded_functions, arguments):
"""
Args:
arguments: The parameters to the dataclass function decorator.
"""
super().__init__(function, overloaded_functions)
self.arguments = arguments
@property
def has_dataclass_init_false(self) -> bool:
"""
Returns:
bool: True if dataclass(init=False)
"""
if not self.arguments.argument_node:
return False
arg_nodes = (
self.arguments.argument_node.children
if self.arguments.argument_node.type == "arglist"
else [self.arguments.argument_node]
)
for arg_node in arg_nodes:
if (
arg_node.type == "argument"
and arg_node.children[0].value == "init"
and arg_node.children[2].value == "False"
):
return True
return False
class DataclassWrapper(ValueWrapper, ClassMixin): class DataclassWrapper(ValueWrapper, ClassMixin):
def get_signatures(self): def get_signatures(self):
param_names = [] param_names = []

View File

@@ -24,7 +24,7 @@ from jedi.inference.value.instance import \
from jedi.inference.base_value import ContextualizedNode, \ from jedi.inference.base_value import ContextualizedNode, \
NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper
from jedi.inference.value import ClassValue, ModuleValue from jedi.inference.value import ClassValue, ModuleValue
from jedi.inference.value.klass import DataclassWrapper from jedi.inference.value.klass import DataclassWrapper, DataclassDecorator
from jedi.inference.value.function import FunctionMixin from jedi.inference.value.function import FunctionMixin
from jedi.inference.value import iterable from jedi.inference.value import iterable
from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \ from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \
@@ -590,11 +590,37 @@ def _random_choice(sequences):
def _dataclass(value, arguments, callback): def _dataclass(value, arguments, callback):
"""
dataclass decorator can be called 2 times with different arguments. One to
customize it dataclass(eq=True) and another one with the class to
transform.
"""
for c in _follow_param(value.inference_state, arguments, 0): for c in _follow_param(value.inference_state, arguments, 0):
if c.is_class(): if c.is_class():
return ValueSet([DataclassWrapper(c)]) # Decorate the class
dataclass_init = (
# Customized decorator
not value.has_dataclass_init_false
if isinstance(value, DataclassDecorator)
# Bare dataclass decorator
else True
)
if dataclass_init:
return ValueSet([DataclassWrapper(c)])
else:
return ValueSet([c])
else: else:
return ValueSet([value]) # Decorator customization
return ValueSet(
[
DataclassDecorator(
value._wrapped_value,
value._overloaded_functions,
arguments=arguments,
)
]
)
return NO_VALUES return NO_VALUES

View File

@@ -350,11 +350,17 @@ def test_wraps_signature(Script, code, signature):
[ [
dedent( dedent(
""" """
@dataclass_transform(init=False) @dataclass(init=False)
class Y(): class X:"""
y: int ),
z = 5 [],
class X(Y):""" False,
],
[
dedent(
"""
@dataclass(eq=True, init=False)
class X:"""
), ),
[], [],
False, False,
@@ -363,11 +369,8 @@ def test_wraps_signature(Script, code, signature):
[ [
dedent( dedent(
""" """
@dataclass_transform() @dataclass()
class Y(): class X:
y: int
z = 5
class X(Y):
def __init__(self, toto: str): def __init__(self, toto: str):
pass pass
""" """
@@ -382,6 +385,7 @@ def test_wraps_signature(Script, code, signature):
"subclass_transformed", "subclass_transformed",
"both_transformed", "both_transformed",
"init_false", "init_false",
"init_false_multiple",
"custom_init", "custom_init",
], ],
) )
@@ -418,7 +422,7 @@ def test_dataclass_signature(
quantity, = sig.params[-1].infer() quantity, = sig.params[-1].infer()
assert quantity.name == 'int' assert quantity.name == 'int'
price, = sig.params[-2].infer() price, = sig.params[-2].infer()
assert price.name == 'float' assert price.name == 'object'
dataclass_transform_cases = [ dataclass_transform_cases = [