From 027e29ec50543679c0d559cc747355bae8947122 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Sat, 15 Feb 2025 20:12:53 +0100 Subject: [PATCH] Support base class and metaclass mode --- jedi/inference/value/klass.py | 122 +++++++++++++++++++++++++- jedi/plugins/stdlib.py | 58 +----------- test/test_inference/test_signature.py | 82 ++++++++++------- 3 files changed, 173 insertions(+), 89 deletions(-) diff --git a/jedi/inference/value/klass.py b/jedi/inference/value/klass.py index d4074f36..1bd064b2 100644 --- a/jedi/inference/value/klass.py +++ b/jedi/inference/value/klass.py @@ -47,11 +47,15 @@ from jedi.inference.filters import ParserTreeFilter from jedi.inference.names import TreeNameDefinition, ValueName from jedi.inference.arguments import unpack_arglist, ValuesArguments from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ - NO_VALUES + NO_VALUES, ValueWrapper from jedi.inference.context import ClassContext from jedi.inference.value.function import FunctionAndClassBase +from jedi.inference.value.decorator import Decoratee from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.plugins import plugin_manager +from inspect import Parameter +from jedi.inference.names import BaseTreeParamName +from jedi.inference.signature import AbstractSignature class ClassName(TreeNameDefinition): @@ -129,6 +133,32 @@ class ClassFilter(ParserTreeFilter): return [name for name in names if self._access_possible(name)] +def get_dataclass_param_names(cls): + """ + ``cls`` is a :class:`ClassMixin`. + """ + param_names = [] + filter_ = cls.as_context().get_global_filter() + # .values ordering is not guaranteed, at least not in + # Python < 3.6, when dicts where not ordered, which is an + # implementation detail anyway. + for name in sorted(filter_.values(), key=lambda name: name.start_pos): + d = name.tree_name.get_definition() + annassign = d.children[1] + if d.type == 'expr_stmt' and annassign.type == 'annassign': + if len(annassign.children) < 4: + default = None + else: + default = annassign.children[3] + param_names.append(DataclassParamName( + parent_context=cls.parent_context, + tree_name=name.tree_name, + annotation_node=annassign.children[1], + default_node=default, + )) + return param_names + + class ClassMixin: def is_class(self): return True @@ -221,6 +251,53 @@ class ClassMixin: assert x is not None yield x + def _has_dataclass_transform_metaclasses(self) -> bool: + for meta in self.get_metaclasses(): + if ( + # Not sure if necessary + isinstance(meta, DataclassWrapper) + or ( + isinstance(meta, Decoratee) + # Internal leakage :| + and isinstance(meta._wrapped_value, DataclassWrapper) + ) + ): + return True + + return False + + def _get_dataclass_transform_signatures(self): + """ + Returns: A non-empty list if the class is dataclass transformed else an + empty list. + """ + param_names = [] + is_dataclass_transform = False + for cls in reversed(list(self.py__mro__())): + if not is_dataclass_transform and ( + isinstance(cls, DataclassWrapper) + or ( + # Some object like CompiledValues would not be compatible + isinstance(cls, ClassMixin) + and cls._has_dataclass_transform_metaclasses() + ) + ): + is_dataclass_transform = True + # Attributes on the decorated class and its base classes are not + # considered to be fields. + continue + + # All inherited behave like dataclass + if is_dataclass_transform: + param_names.extend( + get_dataclass_param_names(cls) + ) + + if is_dataclass_transform: + return [DataclassSignature(cls, param_names)] + else: + [] + def get_signatures(self): # Since calling staticmethod without a function is illegal, the Jedi # plugin doesn't return anything. Therefore call directly and get what @@ -232,7 +309,12 @@ class ClassMixin: return sigs args = ValuesArguments([]) init_funcs = self.py__call__(args).py__getattribute__('__init__') - return [sig.bind(self) for sig in init_funcs.get_signatures()] + + dataclass_sigs = self._get_dataclass_transform_signatures() + if dataclass_sigs: + return dataclass_sigs + else: + return [sig.bind(self) for sig in init_funcs.get_signatures()] def _as_context(self): return ClassContext(self) @@ -319,6 +401,42 @@ class ClassMixin: return ValueSet({self}) +class DataclassParamName(BaseTreeParamName): + def __init__(self, parent_context, tree_name, annotation_node, default_node): + super().__init__(parent_context, tree_name) + self.annotation_node = annotation_node + self.default_node = default_node + + def get_kind(self): + return Parameter.POSITIONAL_OR_KEYWORD + + def infer(self): + if self.annotation_node is None: + return NO_VALUES + else: + return self.parent_context.infer_node(self.annotation_node) + + +class DataclassSignature(AbstractSignature): + def __init__(self, value, param_names): + super().__init__(value) + self._param_names = param_names + + def get_param_names(self, resolve_stars=False): + return self._param_names + + +class DataclassWrapper(ValueWrapper, ClassMixin): + def get_signatures(self): + param_names = [] + for cls in reversed(list(self.py__mro__())): + if isinstance(cls, DataclassWrapper): + param_names.extend( + get_dataclass_param_names(cls) + ) + return [DataclassSignature(cls, param_names)] + + class ClassValue(ClassMixin, FunctionAndClassBase, metaclass=CachedMetaClass): api_type = 'class' diff --git a/jedi/plugins/stdlib.py b/jedi/plugins/stdlib.py index f864a3f5..d85e50db 100644 --- a/jedi/plugins/stdlib.py +++ b/jedi/plugins/stdlib.py @@ -11,7 +11,6 @@ compiled module that returns the types for C-builtins. """ import parso import os -from inspect import Parameter from jedi import debug from jedi.inference.utils import safe_property @@ -25,15 +24,15 @@ from jedi.inference.value.instance import \ from jedi.inference.base_value import ContextualizedNode, \ NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper from jedi.inference.value import ClassValue, ModuleValue -from jedi.inference.value.klass import ClassMixin +from jedi.inference.value.klass import DataclassWrapper from jedi.inference.value.function import FunctionMixin from jedi.inference.value import iterable from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \ LazyKnownValues -from jedi.inference.names import ValueName, BaseTreeParamName +from jedi.inference.names import ValueName from jedi.inference.filters import AttributeOverwrite, publish_method, \ ParserTreeFilter, DictFilter -from jedi.inference.signature import AbstractSignature, SignatureWrapper +from jedi.inference.signature import SignatureWrapper # Copied from Python 3.6's stdlib. @@ -599,57 +598,6 @@ def _dataclass(value, arguments, callback): return NO_VALUES -class DataclassWrapper(ValueWrapper, ClassMixin): - def get_signatures(self): - param_names = [] - for cls in reversed(list(self.py__mro__())): - if isinstance(cls, DataclassWrapper): - filter_ = cls.as_context().get_global_filter() - # .values ordering is not guaranteed, at least not in - # Python < 3.6, when dicts where not ordered, which is an - # implementation detail anyway. - for name in sorted(filter_.values(), key=lambda name: name.start_pos): - d = name.tree_name.get_definition() - annassign = d.children[1] - if d.type == 'expr_stmt' and annassign.type == 'annassign': - if len(annassign.children) < 4: - default = None - else: - default = annassign.children[3] - param_names.append(DataclassParamName( - parent_context=cls.parent_context, - tree_name=name.tree_name, - annotation_node=annassign.children[1], - default_node=default, - )) - return [DataclassSignature(cls, param_names)] - - -class DataclassSignature(AbstractSignature): - def __init__(self, value, param_names): - super().__init__(value) - self._param_names = param_names - - def get_param_names(self, resolve_stars=False): - return self._param_names - - -class DataclassParamName(BaseTreeParamName): - def __init__(self, parent_context, tree_name, annotation_node, default_node): - super().__init__(parent_context, tree_name) - self.annotation_node = annotation_node - self.default_node = default_node - - def get_kind(self): - return Parameter.POSITIONAL_OR_KEYWORD - - def infer(self): - if self.annotation_node is None: - return NO_VALUES - else: - return self.parent_context.infer_node(self.annotation_node) - - class ItemGetterCallable(ValueWrapper): def __init__(self, instance, args_value_set): super().__init__(instance) diff --git a/test/test_inference/test_signature.py b/test/test_inference/test_signature.py index cdfbe60e..4c683f9d 100644 --- a/test/test_inference/test_signature.py +++ b/test/test_inference/test_signature.py @@ -354,23 +354,56 @@ def test_dataclass_signature(Script, skip_pre_python37, start, start_params): assert price.name == 'float' +dataclass_transform_cases = [ + # Direct + ['@dataclass_transform\nclass X:', []], + # With params + ['@dataclass_transform(eq=True)\nclass X:', []], + # Subclass + [dedent(''' + class Y(): + y: int + @dataclass_transform + class X(Y):'''), []], + # Both classes + [dedent(''' + @dataclass_transform + class Y(): + y: int + z = 5 + @dataclass_transform + class X(Y):'''), ['y']], + # Base class + [dedent(''' + @dataclass_transform + class Y(): + y: int + z = 5 + class X(Y):'''), []], + # Alternative decorator + [dedent(''' + @dataclass_transform + def create_model(cls): + return cls + @create_model + class X:'''), []], + # Metaclass + [dedent(''' + @dataclass_transform + class ModelMeta(): + y: int + z = 5 + class ModelBase(metaclass=ModelMeta): + t: int + p = 5 + class X(ModelBase):'''), []], +] + +ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"] + + @pytest.mark.parametrize( - 'start, start_params', [ - ['@dataclass_transform\nclass X:', []], - ['@dataclass_transform(eq=True)\nclass X:', []], - [dedent(''' - class Y(): - y: int - @dataclass_transform - class X(Y):'''), []], - [dedent(''' - @dataclass_transform - class Y(): - y: int - z = 5 - @dataclass_transform - class X(Y):'''), ['y']], - ] + 'start, start_params', dataclass_transform_cases, ids=ids ) def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params): code = dedent(''' @@ -392,22 +425,7 @@ def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, sta @pytest.mark.parametrize( - 'start, start_params', [ - ['@dataclass_transform\nclass X:', []], - ['@dataclass_transform(eq=True)\nclass X:', []], - [dedent(''' - class Y(): - y: int - @dataclass_transform - class X(Y):'''), []], - [dedent(''' - @dataclass_transform - class Y(): - y: int - z = 5 - @dataclass_transform - class X(Y):'''), ['y']], - ] + 'start, start_params', dataclass_transform_cases, ids=ids ) def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params): code = dedent('''