Support base class and metaclass mode

This commit is contained in:
Eric Masseran
2025-02-15 20:12:53 +01:00
parent f9beef0f6b
commit 027e29ec50
3 changed files with 173 additions and 89 deletions

View File

@@ -47,11 +47,15 @@ from jedi.inference.filters import ParserTreeFilter
from jedi.inference.names import TreeNameDefinition, ValueName from jedi.inference.names import TreeNameDefinition, ValueName
from jedi.inference.arguments import unpack_arglist, ValuesArguments from jedi.inference.arguments import unpack_arglist, ValuesArguments
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \ from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
NO_VALUES NO_VALUES, ValueWrapper
from jedi.inference.context import ClassContext from jedi.inference.context import ClassContext
from jedi.inference.value.function import FunctionAndClassBase from jedi.inference.value.function import FunctionAndClassBase
from jedi.inference.value.decorator import Decoratee
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
from jedi.plugins import plugin_manager from jedi.plugins import plugin_manager
from inspect import Parameter
from jedi.inference.names import BaseTreeParamName
from jedi.inference.signature import AbstractSignature
class ClassName(TreeNameDefinition): class ClassName(TreeNameDefinition):
@@ -129,6 +133,32 @@ class ClassFilter(ParserTreeFilter):
return [name for name in names if self._access_possible(name)] return [name for name in names if self._access_possible(name)]
def get_dataclass_param_names(cls):
"""
``cls`` is a :class:`ClassMixin`.
"""
param_names = []
filter_ = cls.as_context().get_global_filter()
# .values ordering is not guaranteed, at least not in
# Python < 3.6, when dicts where not ordered, which is an
# implementation detail anyway.
for name in sorted(filter_.values(), key=lambda name: name.start_pos):
d = name.tree_name.get_definition()
annassign = d.children[1]
if d.type == 'expr_stmt' and annassign.type == 'annassign':
if len(annassign.children) < 4:
default = None
else:
default = annassign.children[3]
param_names.append(DataclassParamName(
parent_context=cls.parent_context,
tree_name=name.tree_name,
annotation_node=annassign.children[1],
default_node=default,
))
return param_names
class ClassMixin: class ClassMixin:
def is_class(self): def is_class(self):
return True return True
@@ -221,6 +251,53 @@ class ClassMixin:
assert x is not None assert x is not None
yield x yield x
def _has_dataclass_transform_metaclasses(self) -> bool:
for meta in self.get_metaclasses():
if (
# Not sure if necessary
isinstance(meta, DataclassWrapper)
or (
isinstance(meta, Decoratee)
# Internal leakage :|
and isinstance(meta._wrapped_value, DataclassWrapper)
)
):
return True
return False
def _get_dataclass_transform_signatures(self):
"""
Returns: A non-empty list if the class is dataclass transformed else an
empty list.
"""
param_names = []
is_dataclass_transform = False
for cls in reversed(list(self.py__mro__())):
if not is_dataclass_transform and (
isinstance(cls, DataclassWrapper)
or (
# Some object like CompiledValues would not be compatible
isinstance(cls, ClassMixin)
and cls._has_dataclass_transform_metaclasses()
)
):
is_dataclass_transform = True
# Attributes on the decorated class and its base classes are not
# considered to be fields.
continue
# All inherited behave like dataclass
if is_dataclass_transform:
param_names.extend(
get_dataclass_param_names(cls)
)
if is_dataclass_transform:
return [DataclassSignature(cls, param_names)]
else:
[]
def get_signatures(self): def get_signatures(self):
# Since calling staticmethod without a function is illegal, the Jedi # Since calling staticmethod without a function is illegal, the Jedi
# plugin doesn't return anything. Therefore call directly and get what # plugin doesn't return anything. Therefore call directly and get what
@@ -232,7 +309,12 @@ class ClassMixin:
return sigs return sigs
args = ValuesArguments([]) args = ValuesArguments([])
init_funcs = self.py__call__(args).py__getattribute__('__init__') init_funcs = self.py__call__(args).py__getattribute__('__init__')
return [sig.bind(self) for sig in init_funcs.get_signatures()]
dataclass_sigs = self._get_dataclass_transform_signatures()
if dataclass_sigs:
return dataclass_sigs
else:
return [sig.bind(self) for sig in init_funcs.get_signatures()]
def _as_context(self): def _as_context(self):
return ClassContext(self) return ClassContext(self)
@@ -319,6 +401,42 @@ class ClassMixin:
return ValueSet({self}) return ValueSet({self})
class DataclassParamName(BaseTreeParamName):
def __init__(self, parent_context, tree_name, annotation_node, default_node):
super().__init__(parent_context, tree_name)
self.annotation_node = annotation_node
self.default_node = default_node
def get_kind(self):
return Parameter.POSITIONAL_OR_KEYWORD
def infer(self):
if self.annotation_node is None:
return NO_VALUES
else:
return self.parent_context.infer_node(self.annotation_node)
class DataclassSignature(AbstractSignature):
def __init__(self, value, param_names):
super().__init__(value)
self._param_names = param_names
def get_param_names(self, resolve_stars=False):
return self._param_names
class DataclassWrapper(ValueWrapper, ClassMixin):
def get_signatures(self):
param_names = []
for cls in reversed(list(self.py__mro__())):
if isinstance(cls, DataclassWrapper):
param_names.extend(
get_dataclass_param_names(cls)
)
return [DataclassSignature(cls, param_names)]
class ClassValue(ClassMixin, FunctionAndClassBase, metaclass=CachedMetaClass): class ClassValue(ClassMixin, FunctionAndClassBase, metaclass=CachedMetaClass):
api_type = 'class' api_type = 'class'

View File

@@ -11,7 +11,6 @@ compiled module that returns the types for C-builtins.
""" """
import parso import parso
import os import os
from inspect import Parameter
from jedi import debug from jedi import debug
from jedi.inference.utils import safe_property from jedi.inference.utils import safe_property
@@ -25,15 +24,15 @@ from jedi.inference.value.instance import \
from jedi.inference.base_value import ContextualizedNode, \ from jedi.inference.base_value import ContextualizedNode, \
NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper
from jedi.inference.value import ClassValue, ModuleValue from jedi.inference.value import ClassValue, ModuleValue
from jedi.inference.value.klass import ClassMixin from jedi.inference.value.klass import DataclassWrapper
from jedi.inference.value.function import FunctionMixin from jedi.inference.value.function import FunctionMixin
from jedi.inference.value import iterable from jedi.inference.value import iterable
from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \ from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \
LazyKnownValues LazyKnownValues
from jedi.inference.names import ValueName, BaseTreeParamName from jedi.inference.names import ValueName
from jedi.inference.filters import AttributeOverwrite, publish_method, \ from jedi.inference.filters import AttributeOverwrite, publish_method, \
ParserTreeFilter, DictFilter ParserTreeFilter, DictFilter
from jedi.inference.signature import AbstractSignature, SignatureWrapper from jedi.inference.signature import SignatureWrapper
# Copied from Python 3.6's stdlib. # Copied from Python 3.6's stdlib.
@@ -599,57 +598,6 @@ def _dataclass(value, arguments, callback):
return NO_VALUES return NO_VALUES
class DataclassWrapper(ValueWrapper, ClassMixin):
def get_signatures(self):
param_names = []
for cls in reversed(list(self.py__mro__())):
if isinstance(cls, DataclassWrapper):
filter_ = cls.as_context().get_global_filter()
# .values ordering is not guaranteed, at least not in
# Python < 3.6, when dicts where not ordered, which is an
# implementation detail anyway.
for name in sorted(filter_.values(), key=lambda name: name.start_pos):
d = name.tree_name.get_definition()
annassign = d.children[1]
if d.type == 'expr_stmt' and annassign.type == 'annassign':
if len(annassign.children) < 4:
default = None
else:
default = annassign.children[3]
param_names.append(DataclassParamName(
parent_context=cls.parent_context,
tree_name=name.tree_name,
annotation_node=annassign.children[1],
default_node=default,
))
return [DataclassSignature(cls, param_names)]
class DataclassSignature(AbstractSignature):
def __init__(self, value, param_names):
super().__init__(value)
self._param_names = param_names
def get_param_names(self, resolve_stars=False):
return self._param_names
class DataclassParamName(BaseTreeParamName):
def __init__(self, parent_context, tree_name, annotation_node, default_node):
super().__init__(parent_context, tree_name)
self.annotation_node = annotation_node
self.default_node = default_node
def get_kind(self):
return Parameter.POSITIONAL_OR_KEYWORD
def infer(self):
if self.annotation_node is None:
return NO_VALUES
else:
return self.parent_context.infer_node(self.annotation_node)
class ItemGetterCallable(ValueWrapper): class ItemGetterCallable(ValueWrapper):
def __init__(self, instance, args_value_set): def __init__(self, instance, args_value_set):
super().__init__(instance) super().__init__(instance)

View File

@@ -354,23 +354,56 @@ def test_dataclass_signature(Script, skip_pre_python37, start, start_params):
assert price.name == 'float' assert price.name == 'float'
dataclass_transform_cases = [
# Direct
['@dataclass_transform\nclass X:', []],
# With params
['@dataclass_transform(eq=True)\nclass X:', []],
# Subclass
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
# Both classes
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
# Base class
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
class X(Y):'''), []],
# Alternative decorator
[dedent('''
@dataclass_transform
def create_model(cls):
return cls
@create_model
class X:'''), []],
# Metaclass
[dedent('''
@dataclass_transform
class ModelMeta():
y: int
z = 5
class ModelBase(metaclass=ModelMeta):
t: int
p = 5
class X(ModelBase):'''), []],
]
ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
'start, start_params', [ 'start, start_params', dataclass_transform_cases, ids=ids
['@dataclass_transform\nclass X:', []],
['@dataclass_transform(eq=True)\nclass X:', []],
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
]
) )
def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params): def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params):
code = dedent(''' code = dedent('''
@@ -392,22 +425,7 @@ def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, sta
@pytest.mark.parametrize( @pytest.mark.parametrize(
'start, start_params', [ 'start, start_params', dataclass_transform_cases, ids=ids
['@dataclass_transform\nclass X:', []],
['@dataclass_transform(eq=True)\nclass X:', []],
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
]
) )
def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params): def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params):
code = dedent(''' code = dedent('''