Support base class and metaclass mode

This commit is contained in:
Eric Masseran
2025-02-15 20:12:53 +01:00
parent f9beef0f6b
commit 027e29ec50
3 changed files with 173 additions and 89 deletions

View File

@@ -47,11 +47,15 @@ from jedi.inference.filters import ParserTreeFilter
from jedi.inference.names import TreeNameDefinition, ValueName
from jedi.inference.arguments import unpack_arglist, ValuesArguments
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
NO_VALUES
NO_VALUES, ValueWrapper
from jedi.inference.context import ClassContext
from jedi.inference.value.function import FunctionAndClassBase
from jedi.inference.value.decorator import Decoratee
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
from jedi.plugins import plugin_manager
from inspect import Parameter
from jedi.inference.names import BaseTreeParamName
from jedi.inference.signature import AbstractSignature
class ClassName(TreeNameDefinition):
@@ -129,6 +133,32 @@ class ClassFilter(ParserTreeFilter):
return [name for name in names if self._access_possible(name)]
def get_dataclass_param_names(cls):
"""
``cls`` is a :class:`ClassMixin`.
"""
param_names = []
filter_ = cls.as_context().get_global_filter()
# .values ordering is not guaranteed, at least not in
# Python < 3.6, when dicts where not ordered, which is an
# implementation detail anyway.
for name in sorted(filter_.values(), key=lambda name: name.start_pos):
d = name.tree_name.get_definition()
annassign = d.children[1]
if d.type == 'expr_stmt' and annassign.type == 'annassign':
if len(annassign.children) < 4:
default = None
else:
default = annassign.children[3]
param_names.append(DataclassParamName(
parent_context=cls.parent_context,
tree_name=name.tree_name,
annotation_node=annassign.children[1],
default_node=default,
))
return param_names
class ClassMixin:
def is_class(self):
return True
@@ -221,6 +251,53 @@ class ClassMixin:
assert x is not None
yield x
def _has_dataclass_transform_metaclasses(self) -> bool:
for meta in self.get_metaclasses():
if (
# Not sure if necessary
isinstance(meta, DataclassWrapper)
or (
isinstance(meta, Decoratee)
# Internal leakage :|
and isinstance(meta._wrapped_value, DataclassWrapper)
)
):
return True
return False
def _get_dataclass_transform_signatures(self):
"""
Returns: A non-empty list if the class is dataclass transformed else an
empty list.
"""
param_names = []
is_dataclass_transform = False
for cls in reversed(list(self.py__mro__())):
if not is_dataclass_transform and (
isinstance(cls, DataclassWrapper)
or (
# Some object like CompiledValues would not be compatible
isinstance(cls, ClassMixin)
and cls._has_dataclass_transform_metaclasses()
)
):
is_dataclass_transform = True
# Attributes on the decorated class and its base classes are not
# considered to be fields.
continue
# All inherited behave like dataclass
if is_dataclass_transform:
param_names.extend(
get_dataclass_param_names(cls)
)
if is_dataclass_transform:
return [DataclassSignature(cls, param_names)]
else:
[]
def get_signatures(self):
# Since calling staticmethod without a function is illegal, the Jedi
# plugin doesn't return anything. Therefore call directly and get what
@@ -232,6 +309,11 @@ class ClassMixin:
return sigs
args = ValuesArguments([])
init_funcs = self.py__call__(args).py__getattribute__('__init__')
dataclass_sigs = self._get_dataclass_transform_signatures()
if dataclass_sigs:
return dataclass_sigs
else:
return [sig.bind(self) for sig in init_funcs.get_signatures()]
def _as_context(self):
@@ -319,6 +401,42 @@ class ClassMixin:
return ValueSet({self})
class DataclassParamName(BaseTreeParamName):
def __init__(self, parent_context, tree_name, annotation_node, default_node):
super().__init__(parent_context, tree_name)
self.annotation_node = annotation_node
self.default_node = default_node
def get_kind(self):
return Parameter.POSITIONAL_OR_KEYWORD
def infer(self):
if self.annotation_node is None:
return NO_VALUES
else:
return self.parent_context.infer_node(self.annotation_node)
class DataclassSignature(AbstractSignature):
def __init__(self, value, param_names):
super().__init__(value)
self._param_names = param_names
def get_param_names(self, resolve_stars=False):
return self._param_names
class DataclassWrapper(ValueWrapper, ClassMixin):
def get_signatures(self):
param_names = []
for cls in reversed(list(self.py__mro__())):
if isinstance(cls, DataclassWrapper):
param_names.extend(
get_dataclass_param_names(cls)
)
return [DataclassSignature(cls, param_names)]
class ClassValue(ClassMixin, FunctionAndClassBase, metaclass=CachedMetaClass):
api_type = 'class'

View File

@@ -11,7 +11,6 @@ compiled module that returns the types for C-builtins.
"""
import parso
import os
from inspect import Parameter
from jedi import debug
from jedi.inference.utils import safe_property
@@ -25,15 +24,15 @@ from jedi.inference.value.instance import \
from jedi.inference.base_value import ContextualizedNode, \
NO_VALUES, ValueSet, ValueWrapper, LazyValueWrapper
from jedi.inference.value import ClassValue, ModuleValue
from jedi.inference.value.klass import ClassMixin
from jedi.inference.value.klass import DataclassWrapper
from jedi.inference.value.function import FunctionMixin
from jedi.inference.value import iterable
from jedi.inference.lazy_value import LazyTreeValue, LazyKnownValue, \
LazyKnownValues
from jedi.inference.names import ValueName, BaseTreeParamName
from jedi.inference.names import ValueName
from jedi.inference.filters import AttributeOverwrite, publish_method, \
ParserTreeFilter, DictFilter
from jedi.inference.signature import AbstractSignature, SignatureWrapper
from jedi.inference.signature import SignatureWrapper
# Copied from Python 3.6's stdlib.
@@ -599,57 +598,6 @@ def _dataclass(value, arguments, callback):
return NO_VALUES
class DataclassWrapper(ValueWrapper, ClassMixin):
def get_signatures(self):
param_names = []
for cls in reversed(list(self.py__mro__())):
if isinstance(cls, DataclassWrapper):
filter_ = cls.as_context().get_global_filter()
# .values ordering is not guaranteed, at least not in
# Python < 3.6, when dicts where not ordered, which is an
# implementation detail anyway.
for name in sorted(filter_.values(), key=lambda name: name.start_pos):
d = name.tree_name.get_definition()
annassign = d.children[1]
if d.type == 'expr_stmt' and annassign.type == 'annassign':
if len(annassign.children) < 4:
default = None
else:
default = annassign.children[3]
param_names.append(DataclassParamName(
parent_context=cls.parent_context,
tree_name=name.tree_name,
annotation_node=annassign.children[1],
default_node=default,
))
return [DataclassSignature(cls, param_names)]
class DataclassSignature(AbstractSignature):
def __init__(self, value, param_names):
super().__init__(value)
self._param_names = param_names
def get_param_names(self, resolve_stars=False):
return self._param_names
class DataclassParamName(BaseTreeParamName):
def __init__(self, parent_context, tree_name, annotation_node, default_node):
super().__init__(parent_context, tree_name)
self.annotation_node = annotation_node
self.default_node = default_node
def get_kind(self):
return Parameter.POSITIONAL_OR_KEYWORD
def infer(self):
if self.annotation_node is None:
return NO_VALUES
else:
return self.parent_context.infer_node(self.annotation_node)
class ItemGetterCallable(ValueWrapper):
def __init__(self, instance, args_value_set):
super().__init__(instance)

View File

@@ -354,15 +354,18 @@ def test_dataclass_signature(Script, skip_pre_python37, start, start_params):
assert price.name == 'float'
@pytest.mark.parametrize(
'start, start_params', [
dataclass_transform_cases = [
# Direct
['@dataclass_transform\nclass X:', []],
# With params
['@dataclass_transform(eq=True)\nclass X:', []],
# Subclass
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
# Both classes
[dedent('''
@dataclass_transform
class Y():
@@ -370,7 +373,37 @@ def test_dataclass_signature(Script, skip_pre_python37, start, start_params):
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
]
# Base class
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
class X(Y):'''), []],
# Alternative decorator
[dedent('''
@dataclass_transform
def create_model(cls):
return cls
@create_model
class X:'''), []],
# Metaclass
[dedent('''
@dataclass_transform
class ModelMeta():
y: int
z = 5
class ModelBase(metaclass=ModelMeta):
t: int
p = 5
class X(ModelBase):'''), []],
]
ids = ["direct", "with_params", "sub", "both", "base", "alternative_decorator", "metaclass"]
@pytest.mark.parametrize(
'start, start_params', dataclass_transform_cases, ids=ids
)
def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, start, start_params):
code = dedent('''
@@ -392,22 +425,7 @@ def test_extensions_dataclass_transform_signature(Script, skip_pre_python37, sta
@pytest.mark.parametrize(
'start, start_params', [
['@dataclass_transform\nclass X:', []],
['@dataclass_transform(eq=True)\nclass X:', []],
[dedent('''
class Y():
y: int
@dataclass_transform
class X(Y):'''), []],
[dedent('''
@dataclass_transform
class Y():
y: int
z = 5
@dataclass_transform
class X(Y):'''), ['y']],
]
'start, start_params', dataclass_transform_cases, ids=ids
)
def test_dataclass_transform_signature(Script, skip_pre_python311, start, start_params):
code = dedent('''